use std::sync::Arc;
use std::{collections::HashMap, ptr::NonNull};

use arrow_array::{
    cast::AsArray, Array, ArrayRef, ArrowNumericType, FixedSizeBinaryArray, FixedSizeListArray,
    GenericListArray, OffsetSizeTrait, PrimitiveArray, RecordBatch, StructArray, UInt32Array,
    UInt8Array,
};
use arrow_array::{
    new_null_array, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array,
};
use arrow_buffer::MutableBuffer;
use arrow_data::ArrayDataBuilder;
use arrow_schema::{ArrowError, DataType, Field, FieldRef, Fields, IntervalUnit, Schema};
use arrow_select::{interleave::interleave, take::take};
use rand::prelude::*;

pub mod deepcopy;
pub mod schema;
pub use schema::*;
pub mod bfloat16;
pub mod floats;
pub use floats::*;
pub mod cast;
pub mod list;
pub mod memory;

type Result<T> = std::result::Result<T, ArrowError>;

/// Extension helpers on [`DataType`] for inspecting Arrow types.
pub trait DataTypeExt {
    /// Returns true if the type is a string or binary type
    /// (`Utf8`, `LargeUtf8`, `Binary`, or `LargeBinary`).
    fn is_binary_like(&self) -> bool;

    /// Returns true if the type is a struct.
    fn is_struct(&self) -> bool;

    /// Returns true if values of this type are laid out with a fixed stride
    /// (no offsets buffer).
    fn is_fixed_stride(&self) -> bool;

    /// Returns true if the type is a dictionary type.
    fn is_dictionary(&self) -> bool;

    /// Returns the number of bytes per value.
    ///
    /// Panics if the type does not have a fixed byte width.
    fn byte_width(&self) -> usize;

    /// Returns the number of bytes per value, or `None` if the type does not
    /// have a fixed byte width.
    fn byte_width_opt(&self) -> Option<usize>;
}

impl DataTypeExt for DataType {
    fn is_binary_like(&self) -> bool {
        use DataType::*;
        matches!(self, Utf8 | Binary | LargeUtf8 | LargeBinary)
    }

    fn is_struct(&self) -> bool {
        matches!(self, Self::Struct(_))
    }

    fn is_fixed_stride(&self) -> bool {
        use DataType::*;
        matches!(
            self,
            Boolean
                | UInt8
                | UInt16
                | UInt32
                | UInt64
                | Int8
                | Int16
                | Int32
                | Int64
                | Float16
                | Float32
                | Float64
                | Decimal128(_, _)
                | Decimal256(_, _)
                | FixedSizeList(_, _)
                | FixedSizeBinary(_)
                | Duration(_)
                | Timestamp(_, _)
                | Date32
                | Date64
                | Time32(_)
                | Time64(_)
        )
    }

    fn is_dictionary(&self) -> bool {
        matches!(self, Self::Dictionary(_, _))
    }

    fn byte_width_opt(&self) -> Option<usize> {
        match self {
            Self::Int8 => Some(1),
            Self::Int16 => Some(2),
            Self::Int32 => Some(4),
            Self::Int64 => Some(8),
            Self::UInt8 => Some(1),
            Self::UInt16 => Some(2),
            Self::UInt32 => Some(4),
            Self::UInt64 => Some(8),
            Self::Float16 => Some(2),
            Self::Float32 => Some(4),
            Self::Float64 => Some(8),
            Self::Date32 => Some(4),
            Self::Date64 => Some(8),
            Self::Time32(_) => Some(4),
            Self::Time64(_) => Some(8),
            Self::Timestamp(_, _) => Some(8),
            Self::Duration(_) => Some(8),
            Self::Decimal128(_, _) => Some(16),
            Self::Decimal256(_, _) => Some(32),
            Self::Interval(unit) => match unit {
                IntervalUnit::YearMonth => Some(4),
                IntervalUnit::DayTime => Some(8),
                IntervalUnit::MonthDayNano => Some(16),
            },
            Self::FixedSizeBinary(s) => Some(*s as usize),
            Self::FixedSizeList(dt, s) => dt
                .data_type()
                .byte_width_opt()
                .map(|width| *s as usize * width),
            _ => None,
        }
    }

    fn byte_width(&self) -> usize {
        self.byte_width_opt()
            .unwrap_or_else(|| panic!("Expecting fixed stride data type, found {:?}", self))
    }
}
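
// Illustrative usage sketch (added for clarity, not part of the original file):
// how `DataTypeExt` answers layout questions about Arrow types.
#[cfg(test)]
mod data_type_ext_example {
    use super::*;

    #[test]
    fn byte_width_of_common_types() {
        // Fixed-stride types report their width in bytes.
        assert_eq!(DataType::Int32.byte_width(), 4);
        assert_eq!(DataType::Decimal128(38, 10).byte_width(), 16);
        assert!(DataType::Float64.is_fixed_stride());
        // Variable-length types are binary-like and have no fixed width.
        assert!(DataType::LargeUtf8.is_binary_like());
        assert_eq!(DataType::Utf8.byte_width_opt(), None);
    }
}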

/// Create a [`GenericListArray`] from a values array and an offsets array.
///
/// `offsets` must contain `number_of_lists + 1` monotonically increasing values.
pub fn try_new_generic_list_array<T: Array, Offset: ArrowNumericType>(
    values: T,
    offsets: &PrimitiveArray<Offset>,
) -> Result<GenericListArray<Offset::Native>>
where
    Offset::Native: OffsetSizeTrait,
{
    let data_type = if Offset::Native::IS_LARGE {
        DataType::LargeList(Arc::new(Field::new(
            "item",
            values.data_type().clone(),
            true,
        )))
    } else {
        DataType::List(Arc::new(Field::new(
            "item",
            values.data_type().clone(),
            true,
        )))
    };
    let data = ArrayDataBuilder::new(data_type)
        .len(offsets.len() - 1)
        .add_buffer(offsets.to_data().buffers()[0].clone())
        .add_child_data(values.into_data())
        .build()?;

    Ok(GenericListArray::from(data))
}
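
// Illustrative usage sketch (added for clarity, not part of the original file):
// building a `ListArray` from flat values and explicit offsets.
#[cfg(test)]
mod list_from_offsets_example {
    use super::*;

    #[test]
    fn build_list_array_from_offsets() {
        // Values [10, 20, 30, 40, 50] split into rows [10, 20], [30], [40, 50].
        let values = Int32Array::from(vec![10, 20, 30, 40, 50]);
        // i32 offsets build a `ListArray`; i64 offsets would build a `LargeListArray`.
        let offsets = Int32Array::from(vec![0, 2, 3, 5]);
        let list = try_new_generic_list_array(values, &offsets).unwrap();
        assert_eq!(list.len(), 3);
        assert_eq!(list.value(1).len(), 1);
    }
}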

/// Create a [`DataType::FixedSizeList`] of `list_width` items of `inner_type`,
/// with a nullable item field named `"item"`.
pub fn fixed_size_list_type(list_width: i32, inner_type: DataType) -> DataType {
    DataType::FixedSizeList(Arc::new(Field::new("item", inner_type, true)), list_width)
}
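
// Illustrative usage sketch (added for clarity, not part of the original file):
// `fixed_size_list_type` composes with `DataTypeExt::byte_width` for
// fixed-width vector types.
#[cfg(test)]
mod fixed_size_list_type_example {
    use super::*;

    #[test]
    fn fixed_size_list_byte_width() {
        // A list of 4 x Float32 is 16 bytes per row.
        let dt = fixed_size_list_type(4, DataType::Float32);
        assert!(dt.is_fixed_stride());
        assert_eq!(dt.byte_width(), 16);
    }
}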

/// Extension helpers on [`FixedSizeListArray`].
pub trait FixedSizeListArrayExt {
    /// Create a [`FixedSizeListArray`] from a flat values array and a list size.
    ///
    /// `values.len()` must be a multiple of `list_size`.
    fn try_new_from_values<T: Array + 'static>(
        values: T,
        list_size: i32,
    ) -> Result<FixedSizeListArray>;

    /// Randomly sample `n` lists from this array.
    ///
    /// If `n >= self.len()` the whole array is returned unchanged.
    fn sample(&self, n: usize) -> Result<FixedSizeListArray>;

    /// Convert integer lists to floating point lists: `Int8`/`Int16`/`Int32`
    /// become `Float32`, while `Int64`/`UInt8`/`UInt32` become `Float64`.
    /// Floating point lists are returned as-is.
    fn convert_to_floating_point(&self) -> Result<FixedSizeListArray>;
}

impl FixedSizeListArrayExt for FixedSizeListArray {
    fn try_new_from_values<T: Array + 'static>(values: T, list_size: i32) -> Result<Self> {
        let field = Arc::new(Field::new("item", values.data_type().clone(), true));
        let values = Arc::new(values);

        Self::try_new(field, list_size, values, None)
    }

    fn sample(&self, n: usize) -> Result<FixedSizeListArray> {
        if n >= self.len() {
            return Ok(self.clone());
        }
        let mut rng = SmallRng::from_entropy();
        let chosen = (0..self.len() as u32).choose_multiple(&mut rng, n);
        take(self, &UInt32Array::from(chosen), None).map(|arr| arr.as_fixed_size_list().clone())
    }

    fn convert_to_floating_point(&self) -> Result<FixedSizeListArray> {
        match self.data_type() {
            DataType::FixedSizeList(field, size) => match field.data_type() {
                DataType::Float16 | DataType::Float32 | DataType::Float64 => Ok(self.clone()),
                DataType::Int8 => Ok(Self::new(
                    Arc::new(arrow_schema::Field::new(
                        field.name(),
                        DataType::Float32,
                        field.is_nullable(),
                    )),
                    *size,
                    Arc::new(Float32Array::from_iter(
                        self.values()
                            .as_any()
                            .downcast_ref::<Int8Array>()
                            .ok_or(ArrowError::ParseError(
                                "Fail to cast primitive array to Int8Type".to_string(),
                            ))?
                            .into_iter()
                            .map(|x| x.map(|y| y as f32)),
                    )),
                    self.nulls().cloned(),
                )),
                DataType::Int16 => Ok(Self::new(
                    Arc::new(arrow_schema::Field::new(
                        field.name(),
                        DataType::Float32,
                        field.is_nullable(),
                    )),
                    *size,
                    Arc::new(Float32Array::from_iter(
                        self.values()
                            .as_any()
                            .downcast_ref::<Int16Array>()
                            .ok_or(ArrowError::ParseError(
                                "Fail to cast primitive array to Int16Type".to_string(),
                            ))?
                            .into_iter()
                            .map(|x| x.map(|y| y as f32)),
                    )),
                    self.nulls().cloned(),
                )),
                DataType::Int32 => Ok(Self::new(
                    Arc::new(arrow_schema::Field::new(
                        field.name(),
                        DataType::Float32,
                        field.is_nullable(),
                    )),
                    *size,
                    Arc::new(Float32Array::from_iter(
                        self.values()
                            .as_any()
                            .downcast_ref::<Int32Array>()
                            .ok_or(ArrowError::ParseError(
                                "Fail to cast primitive array to Int32Type".to_string(),
                            ))?
                            .into_iter()
                            .map(|x| x.map(|y| y as f32)),
                    )),
                    self.nulls().cloned(),
                )),
                DataType::Int64 => Ok(Self::new(
                    Arc::new(arrow_schema::Field::new(
                        field.name(),
                        DataType::Float64,
                        field.is_nullable(),
                    )),
                    *size,
                    Arc::new(Float64Array::from_iter(
                        self.values()
                            .as_any()
                            .downcast_ref::<Int64Array>()
                            .ok_or(ArrowError::ParseError(
                                "Fail to cast primitive array to Int64Type".to_string(),
                            ))?
                            .into_iter()
                            .map(|x| x.map(|y| y as f64)),
                    )),
                    self.nulls().cloned(),
                )),
                DataType::UInt8 => Ok(Self::new(
                    Arc::new(arrow_schema::Field::new(
                        field.name(),
                        DataType::Float64,
                        field.is_nullable(),
                    )),
                    *size,
                    Arc::new(Float64Array::from_iter(
                        self.values()
                            .as_any()
                            .downcast_ref::<UInt8Array>()
                            .ok_or(ArrowError::ParseError(
                                "Fail to cast primitive array to UInt8Type".to_string(),
                            ))?
                            .into_iter()
                            .map(|x| x.map(|y| y as f64)),
                    )),
                    self.nulls().cloned(),
                )),
                DataType::UInt32 => Ok(Self::new(
                    Arc::new(arrow_schema::Field::new(
                        field.name(),
                        DataType::Float64,
                        field.is_nullable(),
                    )),
                    *size,
                    Arc::new(Float64Array::from_iter(
                        self.values()
                            .as_any()
                            .downcast_ref::<UInt32Array>()
                            .ok_or(ArrowError::ParseError(
                                "Fail to cast primitive array to UInt32Type".to_string(),
                            ))?
                            .into_iter()
                            .map(|x| x.map(|y| y as f64)),
                    )),
                    self.nulls().cloned(),
                )),
                data_type => Err(ArrowError::ParseError(format!(
                    "Expected a floating point or integer item type, got {:?}",
                    data_type
                ))),
            },
            data_type => Err(ArrowError::ParseError(format!(
                "Expected a FixedSizeList, got {:?}",
                data_type
            ))),
        }
    }
}
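
// Illustrative usage sketch (added for clarity, not part of the original file):
// building, sampling and converting fixed-size list ("vector") columns with
// `FixedSizeListArrayExt`.
#[cfg(test)]
mod fixed_size_list_ext_example {
    use super::*;

    #[test]
    fn build_sample_and_convert() {
        // Six Int8 values reshaped into three lists of two.
        let values = Int8Array::from(vec![1, 2, 3, 4, 5, 6]);
        let fsl = FixedSizeListArray::try_new_from_values(values, 2).unwrap();
        assert_eq!(fsl.len(), 3);

        // Sampling more rows than exist returns the array unchanged.
        assert_eq!(fsl.sample(10).unwrap().len(), 3);
        // Otherwise a random subset of rows is returned.
        assert_eq!(fsl.sample(2).unwrap().len(), 2);

        // Int8 lists are widened to Float32 lists.
        let as_float = fsl.convert_to_floating_point().unwrap();
        assert_eq!(as_float.value_type(), DataType::Float32);
    }
}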

/// Downcast an array to a [`FixedSizeListArray`].
///
/// Panics if `arr` is not a `FixedSizeListArray`.
pub fn as_fixed_size_list_array(arr: &dyn Array) -> &FixedSizeListArray {
    arr.as_any().downcast_ref::<FixedSizeListArray>().unwrap()
}

/// Extension helpers on [`FixedSizeBinaryArray`].
pub trait FixedSizeBinaryArrayExt {
    /// Create a [`FixedSizeBinaryArray`] from a flat `UInt8Array` and a byte stride.
    ///
    /// `values.len()` must be a multiple of `stride`.
    fn try_new_from_values(values: &UInt8Array, stride: i32) -> Result<FixedSizeBinaryArray>;
}

impl FixedSizeBinaryArrayExt for FixedSizeBinaryArray {
    fn try_new_from_values(values: &UInt8Array, stride: i32) -> Result<Self> {
        let data_type = DataType::FixedSizeBinary(stride);
        let data = ArrayDataBuilder::new(data_type)
            .len(values.len() / stride as usize)
            .add_buffer(values.to_data().buffers()[0].clone())
            .build()?;
        Ok(Self::from(data))
    }
}
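
// Illustrative usage sketch (added for clarity, not part of the original file):
// packing flat bytes into a `FixedSizeBinaryArray` with
// `FixedSizeBinaryArrayExt::try_new_from_values`.
#[cfg(test)]
mod fixed_size_binary_ext_example {
    use super::*;

    #[test]
    fn pack_bytes_with_stride() {
        // Six bytes with a stride of 3 yield two fixed-size binary values.
        let bytes = UInt8Array::from(vec![1u8, 2, 3, 4, 5, 6]);
        let arr = FixedSizeBinaryArray::try_new_from_values(&bytes, 3).unwrap();
        assert_eq!(arr.len(), 2);
        assert_eq!(arr.value(0), &[1u8, 2, 3]);
        assert_eq!(arr.value(1), &[4u8, 5, 6]);
    }
}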

/// Downcast an array to a [`FixedSizeBinaryArray`].
///
/// Panics if `arr` is not a `FixedSizeBinaryArray`.
pub fn as_fixed_size_binary_array(arr: &dyn Array) -> &FixedSizeBinaryArray {
    arr.as_any().downcast_ref::<FixedSizeBinaryArray>().unwrap()
}

/// Iterate over a string array (`Utf8` or `LargeUtf8`) as `Option<&str>` values.
///
/// Panics if `arr` is not a string array.
pub fn iter_str_array(arr: &dyn Array) -> Box<dyn Iterator<Item = Option<&str>> + '_> {
    match arr.data_type() {
        DataType::Utf8 => Box::new(arr.as_string::<i32>().iter()),
        DataType::LargeUtf8 => Box::new(arr.as_string::<i64>().iter()),
        _ => panic!("Expecting Utf8 or LargeUtf8, found {:?}", arr.data_type()),
    }
}
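
// Illustrative usage sketch (added for clarity, not part of the original file):
// `iter_str_array` lets the caller walk `Utf8` and `LargeUtf8` columns through
// one iterator type.
#[cfg(test)]
mod iter_str_array_example {
    use super::*;
    use arrow_array::{LargeStringArray, StringArray};

    #[test]
    fn iterate_small_and_large_strings() {
        let small = StringArray::from(vec![Some("a"), None, Some("c")]);
        let large = LargeStringArray::from(vec![Some("a"), None, Some("c")]);
        let from_small: Vec<Option<&str>> = iter_str_array(&small).collect();
        let from_large: Vec<Option<&str>> = iter_str_array(&large).collect();
        assert_eq!(from_small, from_large);
        assert_eq!(from_small, vec![Some("a"), None, Some("c")]);
    }
}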

/// Extension helpers on [`RecordBatch`] for schema and column manipulation.
pub trait RecordBatchExt {
    /// Append a new column to this record batch, returning a new batch.
    fn try_with_column(&self, field: Field, arr: ArrayRef) -> Result<RecordBatch>;

    /// Insert a new column at `index`, returning a new batch.
    fn try_with_column_at(&self, index: usize, field: Field, arr: ArrayRef) -> Result<RecordBatch>;

    /// Create a new record batch from a [`StructArray`], carrying over this batch's metadata.
    fn try_new_from_struct_array(&self, arr: StructArray) -> Result<RecordBatch>;

    /// Merge the columns of `other` into this batch, recursively merging struct columns.
    ///
    /// Both batches must have the same number of rows.
    fn merge(&self, other: &RecordBatch) -> Result<RecordBatch>;

    /// Merge the columns of `other` into this batch, ordering the output columns
    /// by `schema`.
    fn merge_with_schema(&self, other: &RecordBatch, schema: &Schema) -> Result<RecordBatch>;

    /// Drop the column with the given name, returning a new batch.
    fn drop_column(&self, name: &str) -> Result<RecordBatch>;

    /// Replace the column with the given name, keeping its existing field definition.
    fn replace_column_by_name(&self, name: &str, column: Arc<dyn Array>) -> Result<RecordBatch>;

    /// Replace the column with the given name, updating its data type in the schema as well.
    fn replace_column_schema_by_name(
        &self,
        name: &str,
        new_data_type: DataType,
        column: Arc<dyn Array>,
    ) -> Result<RecordBatch>;

    /// Get a (possibly nested) column by a dot-separated qualified name, e.g. `"a.b.c"`.
    fn column_by_qualified_name(&self, name: &str) -> Option<&ArrayRef>;

    /// Project this batch onto `schema`, selecting and reordering columns by name,
    /// recursively for struct columns.
    fn project_by_schema(&self, schema: &Schema) -> Result<RecordBatch>;

    /// The metadata of this batch's schema.
    fn metadata(&self) -> &HashMap<String, String>;

    /// Add a key/value pair to the schema metadata, returning a new batch.
    fn add_metadata(&self, key: String, value: String) -> Result<RecordBatch> {
        let mut metadata = self.metadata().clone();
        metadata.insert(key, value);
        self.with_metadata(metadata)
    }

    /// Replace the schema metadata, returning a new batch.
    fn with_metadata(&self, metadata: HashMap<String, String>) -> Result<RecordBatch>;

    /// Take rows at the given indices, returning a new batch.
    fn take(&self, indices: &UInt32Array) -> Result<RecordBatch>;
}

impl RecordBatchExt for RecordBatch {
    fn try_with_column(&self, field: Field, arr: ArrayRef) -> Result<Self> {
        let new_schema = Arc::new(self.schema().as_ref().try_with_column(field)?);
        let mut new_columns = self.columns().to_vec();
        new_columns.push(arr);
        Self::try_new(new_schema, new_columns)
    }

    fn try_with_column_at(&self, index: usize, field: Field, arr: ArrayRef) -> Result<Self> {
        let new_schema = Arc::new(self.schema().as_ref().try_with_column_at(index, field)?);
        let mut new_columns = self.columns().to_vec();
        new_columns.insert(index, arr);
        Self::try_new(new_schema, new_columns)
    }

    fn try_new_from_struct_array(&self, arr: StructArray) -> Result<Self> {
        let schema = Arc::new(Schema::new_with_metadata(
            arr.fields().to_vec(),
            self.schema().metadata.clone(),
        ));
        let batch = Self::from(arr);
        batch.with_schema(schema)
    }

    fn merge(&self, other: &Self) -> Result<Self> {
        if self.num_rows() != other.num_rows() {
            return Err(ArrowError::InvalidArgumentError(format!(
                "Attempt to merge two RecordBatch with different sizes: {} != {}",
                self.num_rows(),
                other.num_rows()
            )));
        }
        let left_struct_array: StructArray = self.clone().into();
        let right_struct_array: StructArray = other.clone().into();
        self.try_new_from_struct_array(merge(&left_struct_array, &right_struct_array))
    }

    fn merge_with_schema(&self, other: &RecordBatch, schema: &Schema) -> Result<RecordBatch> {
        if self.num_rows() != other.num_rows() {
            return Err(ArrowError::InvalidArgumentError(format!(
                "Attempt to merge two RecordBatch with different sizes: {} != {}",
                self.num_rows(),
                other.num_rows()
            )));
        }
        let left_struct_array: StructArray = self.clone().into();
        let right_struct_array: StructArray = other.clone().into();
        self.try_new_from_struct_array(merge_with_schema(
            &left_struct_array,
            &right_struct_array,
            schema.fields(),
        ))
    }

    fn drop_column(&self, name: &str) -> Result<Self> {
        let mut fields = vec![];
        let mut columns = vec![];
        for i in 0..self.schema().fields.len() {
            if self.schema().field(i).name() != name {
                fields.push(self.schema().field(i).clone());
                columns.push(self.column(i).clone());
            }
        }
        Self::try_new(
            Arc::new(Schema::new_with_metadata(
                fields,
                self.schema().metadata().clone(),
            )),
            columns,
        )
    }

    fn replace_column_by_name(&self, name: &str, column: Arc<dyn Array>) -> Result<RecordBatch> {
        let mut columns = self.columns().to_vec();
        let field_i = self
            .schema()
            .fields()
            .iter()
            .position(|f| f.name() == name)
            .ok_or_else(|| ArrowError::SchemaError(format!("Field {} does not exist", name)))?;
        columns[field_i] = column;
        Self::try_new(self.schema(), columns)
    }

    fn replace_column_schema_by_name(
        &self,
        name: &str,
        new_data_type: DataType,
        column: Arc<dyn Array>,
    ) -> Result<RecordBatch> {
        let fields = self
            .schema()
            .fields()
            .iter()
            .map(|x| {
                if x.name() != name {
                    x.clone()
                } else {
                    let new_field = Field::new(name, new_data_type.clone(), x.is_nullable());
                    Arc::new(new_field)
                }
            })
            .collect::<Vec<_>>();
        let schema = Schema::new_with_metadata(fields, self.schema().metadata.clone());
        let mut columns = self.columns().to_vec();
        let field_i = self
            .schema()
            .fields()
            .iter()
            .position(|f| f.name() == name)
            .ok_or_else(|| ArrowError::SchemaError(format!("Field {} does not exist", name)))?;
        columns[field_i] = column;
        Self::try_new(Arc::new(schema), columns)
    }

    fn column_by_qualified_name(&self, name: &str) -> Option<&ArrayRef> {
        let split = name.split('.').collect::<Vec<_>>();
        if split.is_empty() {
            return None;
        }

        self.column_by_name(split[0])
            .and_then(|arr| get_sub_array(arr, &split[1..]))
    }

    fn project_by_schema(&self, schema: &Schema) -> Result<Self> {
        let struct_array: StructArray = self.clone().into();
        self.try_new_from_struct_array(project(&struct_array, schema.fields())?)
    }

    fn metadata(&self) -> &HashMap<String, String> {
        self.schema_ref().metadata()
    }

    fn with_metadata(&self, metadata: HashMap<String, String>) -> Result<RecordBatch> {
        let mut schema = self.schema_ref().as_ref().clone();
        schema.metadata = metadata;
        Self::try_new(schema.into(), self.columns().into())
    }

    fn take(&self, indices: &UInt32Array) -> Result<Self> {
        let struct_array: StructArray = self.clone().into();
        let taken = take(&struct_array, indices, None)?;
        self.try_new_from_struct_array(taken.as_struct().clone())
    }
}
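
// Illustrative usage sketch (added for clarity, not part of the original file):
// a few of the `RecordBatchExt` helpers chained together on a small batch.
#[cfg(test)]
mod record_batch_ext_example {
    use super::*;

    #[test]
    fn add_drop_and_metadata() {
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)]));
        let batch = RecordBatch::try_new(
            schema,
            vec![Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef],
        )
        .unwrap();

        // Append a second column, then address it by (qualified) name.
        let with_b = batch
            .try_with_column(
                Field::new("b", DataType::Int32, true),
                Arc::new(Int32Array::from(vec![4, 5, 6])) as ArrayRef,
            )
            .unwrap();
        assert_eq!(with_b.num_columns(), 2);
        assert!(with_b.column_by_qualified_name("b").is_some());

        // Drop it again and attach some schema metadata.
        let dropped = with_b.drop_column("b").unwrap();
        assert_eq!(dropped.num_columns(), 1);
        let tagged = dropped
            .add_metadata("origin".to_string(), "example".to_string())
            .unwrap();
        assert_eq!(tagged.metadata().get("origin").unwrap(), "example");
    }
}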

fn project(struct_array: &StructArray, fields: &Fields) -> Result<StructArray> {
    if fields.is_empty() {
        return Ok(StructArray::new_empty_fields(
            struct_array.len(),
            struct_array.nulls().cloned(),
        ));
    }
    let mut columns: Vec<ArrayRef> = vec![];
    for field in fields.iter() {
        if let Some(col) = struct_array.column_by_name(field.name()) {
            match field.data_type() {
                DataType::Struct(subfields) => {
                    let projected = project(col.as_struct(), subfields)?;
                    columns.push(Arc::new(projected));
                }
                _ => {
                    columns.push(col.clone());
                }
            }
        } else {
            return Err(ArrowError::SchemaError(format!(
                "field {} does not exist in the RecordBatch",
                field.name()
            )));
        }
    }
    StructArray::try_new(fields.clone(), columns, None)
}

fn lists_have_same_offsets_helper<T: OffsetSizeTrait>(left: &dyn Array, right: &dyn Array) -> bool {
    let left_list: &GenericListArray<T> = left.as_list();
    let right_list: &GenericListArray<T> = right.as_list();
    left_list.offsets().inner() == right_list.offsets().inner()
}

fn merge_list_structs_helper<T: OffsetSizeTrait>(
    left: &dyn Array,
    right: &dyn Array,
    items_field_name: impl Into<String>,
    items_nullable: bool,
) -> Arc<dyn Array> {
    let left_list: &GenericListArray<T> = left.as_list();
    let right_list: &GenericListArray<T> = right.as_list();
    let left_struct = left_list.values();
    let right_struct = right_list.values();
    let left_struct_arr = left_struct.as_struct();
    let right_struct_arr = right_struct.as_struct();
    let merged_items = Arc::new(merge(left_struct_arr, right_struct_arr));
    let items_field = Arc::new(Field::new(
        items_field_name,
        merged_items.data_type().clone(),
        items_nullable,
    ));
    Arc::new(GenericListArray::<T>::new(
        items_field,
        left_list.offsets().clone(),
        merged_items,
        left_list.nulls().cloned(),
    ))
}

fn merge_list_struct_null_helper<T: OffsetSizeTrait>(
    left: &dyn Array,
    right: &dyn Array,
    not_null: &dyn Array,
    items_field_name: impl Into<String>,
) -> Arc<dyn Array> {
    let left_list: &GenericListArray<T> = left.as_list::<T>();
    let not_null_list = not_null.as_list::<T>();
    let right_list = right.as_list::<T>();

    let left_struct = left_list.values().as_struct();
    let not_null_struct: &StructArray = not_null_list.values().as_struct();
    let right_struct = right_list.values().as_struct();

    let values_len = not_null_list.values().len();
    let mut merged_fields =
        Vec::with_capacity(not_null_struct.num_columns() + right_struct.num_columns());
    let mut merged_columns =
        Vec::with_capacity(not_null_struct.num_columns() + right_struct.num_columns());

    for (_, field) in left_struct.columns().iter().zip(left_struct.fields()) {
        merged_fields.push(field.clone());
        if let Some(val) = not_null_struct.column_by_name(field.name()) {
            merged_columns.push(val.clone());
        } else {
            merged_columns.push(new_null_array(field.data_type(), values_len))
        }
    }
    for (_, field) in right_struct
        .columns()
        .iter()
        .zip(right_struct.fields())
        .filter(|(_, field)| left_struct.column_by_name(field.name()).is_none())
    {
        merged_fields.push(field.clone());
        if let Some(val) = not_null_struct.column_by_name(field.name()) {
            merged_columns.push(val.clone());
        } else {
            merged_columns.push(new_null_array(field.data_type(), values_len));
        }
    }

    let merged_struct = Arc::new(StructArray::new(
        Fields::from(merged_fields),
        merged_columns,
        not_null_struct.nulls().cloned(),
    ));
    let items_field = Arc::new(Field::new(
        items_field_name,
        merged_struct.data_type().clone(),
        true,
    ));
    Arc::new(GenericListArray::<T>::new(
        items_field,
        not_null_list.offsets().clone(),
        merged_struct,
        not_null_list.nulls().cloned(),
    ))
}

fn merge_list_struct_null(
    left: &dyn Array,
    right: &dyn Array,
    not_null: &dyn Array,
) -> Arc<dyn Array> {
    match left.data_type() {
        DataType::List(left_field) => {
            merge_list_struct_null_helper::<i32>(left, right, not_null, left_field.name())
        }
        DataType::LargeList(left_field) => {
            merge_list_struct_null_helper::<i64>(left, right, not_null, left_field.name())
        }
        _ => unreachable!(),
    }
}

fn merge_list_struct(left: &dyn Array, right: &dyn Array) -> Arc<dyn Array> {
    // If one side is entirely null, merge against the non-null side and let the
    // missing sub-fields be filled with nulls.
    if left.null_count() == left.len() {
        return merge_list_struct_null(left, right, right);
    } else if right.null_count() == right.len() {
        return merge_list_struct_null(left, right, left);
    }
    match (left.data_type(), right.data_type()) {
        (DataType::List(left_field), DataType::List(_)) => {
            if !lists_have_same_offsets_helper::<i32>(left, right) {
                panic!("Attempt to merge list struct arrays which do not have same offsets");
            }
            merge_list_structs_helper::<i32>(
                left,
                right,
                left_field.name(),
                left_field.is_nullable(),
            )
        }
        (DataType::LargeList(left_field), DataType::LargeList(_)) => {
            if !lists_have_same_offsets_helper::<i64>(left, right) {
                panic!("Attempt to merge list struct arrays which do not have same offsets");
            }
            merge_list_structs_helper::<i64>(
                left,
                right,
                left_field.name(),
                left_field.is_nullable(),
            )
        }
        _ => unreachable!(),
    }
}

fn merge(left_struct_array: &StructArray, right_struct_array: &StructArray) -> StructArray {
    let mut fields: Vec<Field> = vec![];
    let mut columns: Vec<ArrayRef> = vec![];
    let right_fields = right_struct_array.fields();
    let right_columns = right_struct_array.columns();

    // Walk the left fields first, merging in the matching right fields.
    for (left_field, left_column) in left_struct_array
        .fields()
        .iter()
        .zip(left_struct_array.columns().iter())
    {
        match right_fields
            .iter()
            .position(|f| f.name() == left_field.name())
        {
            Some(right_index) => {
                let right_field = right_fields.get(right_index).unwrap();
                let right_column = right_columns.get(right_index).unwrap();
                match (left_field.data_type(), right_field.data_type()) {
                    (DataType::Struct(_), DataType::Struct(_)) => {
                        // Both sides are structs: merge their children recursively.
                        let left_sub_array = left_column.as_struct();
                        let right_sub_array = right_column.as_struct();
                        let merged_sub_array = merge(left_sub_array, right_sub_array);
                        fields.push(Field::new(
                            left_field.name(),
                            merged_sub_array.data_type().clone(),
                            left_field.is_nullable(),
                        ));
                        columns.push(Arc::new(merged_sub_array) as ArrayRef);
                    }
                    (DataType::List(left_list), DataType::List(right_list))
                        if left_list.data_type().is_struct()
                            && right_list.data_type().is_struct() =>
                    {
                        if left_list.data_type() == right_list.data_type() {
                            // Identical list-of-struct types: keep the left column as-is.
                            fields.push(left_field.as_ref().clone());
                            columns.push(left_column.clone());
                        } else {
                            // Different sub-fields: merge the two list-of-struct columns.
                            let merged_sub_array = merge_list_struct(&left_column, &right_column);

                            fields.push(Field::new(
                                left_field.name(),
                                merged_sub_array.data_type().clone(),
                                left_field.is_nullable(),
                            ));
                            columns.push(merged_sub_array);
                        }
                    }
                    _ => {
                        // Same field name with a non-mergeable type: prefer the left column.
                        fields.push(left_field.as_ref().clone());
                        columns.push(left_column.clone());
                    }
                }
            }
            None => {
                fields.push(left_field.as_ref().clone());
                columns.push(left_column.clone());
            }
        }
    }

    // Then append the fields that only exist on the right side.
    right_fields
        .iter()
        .zip(right_columns.iter())
        .for_each(|(field, column)| {
            if !left_struct_array
                .fields()
                .iter()
                .any(|f| f.name() == field.name())
            {
                fields.push(field.as_ref().clone());
                columns.push(column.clone() as ArrayRef);
            }
        });

    let zipped: Vec<(FieldRef, ArrayRef)> = fields
        .iter()
        .cloned()
        .map(Arc::new)
        .zip(columns.iter().cloned())
        .collect::<Vec<_>>();
    StructArray::from(zipped)
}

fn merge_with_schema(
    left_struct_array: &StructArray,
    right_struct_array: &StructArray,
    fields: &Fields,
) -> StructArray {
    // Two data types are considered the same "kind" if they are both structs
    // or both non-structs.
    fn same_type_kind(left: &DataType, right: &DataType) -> bool {
        match (left, right) {
            (DataType::Struct(_), DataType::Struct(_)) => true,
            (DataType::Struct(_), _) => false,
            (_, DataType::Struct(_)) => false,
            _ => true,
        }
    }

    let mut output_fields: Vec<Field> = Vec::with_capacity(fields.len());
    let mut columns: Vec<ArrayRef> = Vec::with_capacity(fields.len());

    let left_fields = left_struct_array.fields();
    let left_columns = left_struct_array.columns();
    let right_fields = right_struct_array.fields();
    let right_columns = right_struct_array.columns();

    for field in fields {
        let left_match_idx = left_fields.iter().position(|f| {
            f.name() == field.name() && same_type_kind(f.data_type(), field.data_type())
        });
        let right_match_idx = right_fields.iter().position(|f| {
            f.name() == field.name() && same_type_kind(f.data_type(), field.data_type())
        });

        match (left_match_idx, right_match_idx) {
            (None, Some(right_idx)) => {
                output_fields.push(right_fields[right_idx].as_ref().clone());
                columns.push(right_columns[right_idx].clone());
            }
            (Some(left_idx), None) => {
                output_fields.push(left_fields[left_idx].as_ref().clone());
                columns.push(left_columns[left_idx].clone());
            }
            (Some(left_idx), Some(right_idx)) => {
                if let DataType::Struct(child_fields) = field.data_type() {
                    let left_sub_array = left_columns[left_idx].as_struct();
                    let right_sub_array = right_columns[right_idx].as_struct();
                    let merged_sub_array =
                        merge_with_schema(left_sub_array, right_sub_array, child_fields);
                    output_fields.push(Field::new(
                        field.name(),
                        merged_sub_array.data_type().clone(),
                        field.is_nullable(),
                    ));
                    columns.push(Arc::new(merged_sub_array) as ArrayRef);
                } else {
                    output_fields.push(left_fields[left_idx].as_ref().clone());
                    columns.push(left_columns[left_idx].clone());
                }
            }
            (None, None) => {
                // The field is not present in either input; skip it.
            }
        }
    }

    let zipped: Vec<(FieldRef, ArrayRef)> = output_fields
        .into_iter()
        .map(Arc::new)
        .zip(columns)
        .collect::<Vec<_>>();
    StructArray::from(zipped)
}

fn get_sub_array<'a>(array: &'a ArrayRef, components: &[&str]) -> Option<&'a ArrayRef> {
    if components.is_empty() {
        return Some(array);
    }
    if !matches!(array.data_type(), DataType::Struct(_)) {
        return None;
    }
    let struct_arr = array.as_struct();
    struct_arr
        .column_by_name(components[0])
        .and_then(|arr| get_sub_array(arr, &components[1..]))
}

/// Interleave rows from multiple [`RecordBatch`]es into a single batch.
///
/// `indices` is a list of `(batch_index, row_index)` pairs; all batches must
/// share the same schema.
pub fn interleave_batches(
    batches: &[RecordBatch],
    indices: &[(usize, usize)],
) -> Result<RecordBatch> {
    let first_batch = batches.first().ok_or_else(|| {
        ArrowError::InvalidArgumentError("Cannot interleave zero RecordBatches".to_string())
    })?;
    let schema = first_batch.schema();
    let num_columns = first_batch.num_columns();
    let mut columns = Vec::with_capacity(num_columns);
    let mut chunks = Vec::with_capacity(batches.len());

    for i in 0..num_columns {
        for batch in batches {
            chunks.push(batch.column(i).as_ref());
        }
        let new_column = interleave(&chunks, indices)?;
        columns.push(new_column);
        chunks.clear();
    }

    RecordBatch::try_new(schema, columns)
}
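
// Illustrative usage sketch (added for clarity, not part of the original file):
// interleaving rows from two batches by `(batch_index, row_index)` pairs.
#[cfg(test)]
mod interleave_batches_example {
    use super::*;

    #[test]
    fn interleave_two_batches() {
        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int32, false)]));
        let batch_a = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef],
        )
        .unwrap();
        let batch_b = RecordBatch::try_new(
            schema,
            vec![Arc::new(Int32Array::from(vec![10, 20, 30])) as ArrayRef],
        )
        .unwrap();

        // Output rows: batch_b[1], batch_a[0], batch_b[2].
        let out = interleave_batches(&[batch_a, batch_b], &[(1, 1), (0, 0), (1, 2)]).unwrap();
        let col = out.column(0).as_primitive::<arrow_array::types::Int32Type>();
        assert_eq!(col.value(0), 20);
        assert_eq!(col.value(1), 1);
        assert_eq!(col.value(2), 30);
    }
}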

/// Extension helpers for building [`arrow_buffer::Buffer`] from [`bytes::Bytes`].
pub trait BufferExt {
    /// Create a buffer from `bytes`, zero-copy when possible.
    ///
    /// If `bytes_per_value` is a power of two and `bytes` is not aligned to it,
    /// the data is copied into a newly allocated buffer; otherwise the buffer
    /// simply wraps `bytes` without copying.
    fn from_bytes_bytes(bytes: bytes::Bytes, bytes_per_value: u64) -> Self;

    /// Copy `bytes` into a new buffer of `size_bytes` bytes, zero-padding the tail.
    ///
    /// `size_bytes` must be at least `bytes.len()`.
    fn copy_bytes_bytes(bytes: bytes::Bytes, size_bytes: usize) -> Self;
}

// Note: assumes `n > 0`.
fn is_pwr_two(n: u64) -> bool {
    n & (n - 1) == 0
}

impl BufferExt for arrow_buffer::Buffer {
    fn from_bytes_bytes(bytes: bytes::Bytes, bytes_per_value: u64) -> Self {
        if is_pwr_two(bytes_per_value) && bytes.as_ptr().align_offset(bytes_per_value as usize) != 0
        {
            // The underlying data is not sufficiently aligned: copy it into a fresh allocation.
            let size_bytes = bytes.len();
            Self::copy_bytes_bytes(bytes, size_bytes)
        } else {
            // Zero-copy: wrap the `Bytes` data and keep the allocation alive via the Arc.
            unsafe {
                Self::from_custom_allocation(
                    NonNull::new(bytes.as_ptr() as _).expect("should be a valid pointer"),
                    bytes.len(),
                    Arc::new(bytes),
                )
            }
        }
    }

    fn copy_bytes_bytes(bytes: bytes::Bytes, size_bytes: usize) -> Self {
        assert!(size_bytes >= bytes.len());
        let mut buf = MutableBuffer::with_capacity(size_bytes);
        let to_fill = size_bytes - bytes.len();
        buf.extend(bytes);
        buf.extend(std::iter::repeat_n(0_u8, to_fill));
        Self::from(buf)
    }
}
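
// Illustrative usage sketch (added for clarity, not part of the original file):
// wrapping and padding `bytes::Bytes` with `BufferExt`.
#[cfg(test)]
mod buffer_ext_example {
    use super::*;
    use arrow_buffer::Buffer;

    #[test]
    fn wrap_and_pad_bytes() {
        // With a 1-byte value width the data is always aligned, so this takes
        // the zero-copy path and simply wraps the allocation.
        let data = bytes::Bytes::from(vec![1u8, 2, 3, 4]);
        let buf = Buffer::from_bytes_bytes(data, 1);
        assert_eq!(buf.as_slice(), &[1u8, 2, 3, 4]);

        // Copying into a larger buffer zero-pads the tail.
        let padded = Buffer::copy_bytes_bytes(bytes::Bytes::from(vec![7u8, 8]), 4);
        assert_eq!(padded.as_slice(), &[7u8, 8, 0, 0]);
    }
}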

#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{new_empty_array, new_null_array, Int32Array, ListArray, StringArray};
    use arrow_buffer::OffsetBuffer;

    #[test]
    fn test_merge_recursive() {
        let a_array = Int32Array::from(vec![Some(1), Some(2), Some(3)]);
        let e_array = Int32Array::from(vec![Some(4), Some(5), Some(6)]);
        let c_array = Int32Array::from(vec![Some(7), Some(8), Some(9)]);
        let d_array = StringArray::from(vec![Some("a"), Some("b"), Some("c")]);

        let left_schema = Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new(
                "b",
                DataType::Struct(vec![Field::new("c", DataType::Int32, true)].into()),
                true,
            ),
        ]);
        let left_batch = RecordBatch::try_new(
            Arc::new(left_schema),
            vec![
                Arc::new(a_array.clone()),
                Arc::new(StructArray::from(vec![(
                    Arc::new(Field::new("c", DataType::Int32, true)),
                    Arc::new(c_array.clone()) as ArrayRef,
                )])),
            ],
        )
        .unwrap();

        let right_schema = Schema::new(vec![
            Field::new("e", DataType::Int32, true),
            Field::new(
                "b",
                DataType::Struct(vec![Field::new("d", DataType::Utf8, true)].into()),
                true,
            ),
        ]);
        let right_batch = RecordBatch::try_new(
            Arc::new(right_schema),
            vec![
                Arc::new(e_array.clone()),
                Arc::new(StructArray::from(vec![(
                    Arc::new(Field::new("d", DataType::Utf8, true)),
                    Arc::new(d_array.clone()) as ArrayRef,
                )])) as ArrayRef,
            ],
        )
        .unwrap();

        let merged_schema = Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new(
                "b",
                DataType::Struct(
                    vec![
                        Field::new("c", DataType::Int32, true),
                        Field::new("d", DataType::Utf8, true),
                    ]
                    .into(),
                ),
                true,
            ),
            Field::new("e", DataType::Int32, true),
        ]);
        let merged_batch = RecordBatch::try_new(
            Arc::new(merged_schema),
            vec![
                Arc::new(a_array) as ArrayRef,
                Arc::new(StructArray::from(vec![
                    (
                        Arc::new(Field::new("c", DataType::Int32, true)),
                        Arc::new(c_array) as ArrayRef,
                    ),
                    (
                        Arc::new(Field::new("d", DataType::Utf8, true)),
                        Arc::new(d_array) as ArrayRef,
                    ),
                ])) as ArrayRef,
                Arc::new(e_array) as ArrayRef,
            ],
        )
        .unwrap();

        let result = left_batch.merge(&right_batch).unwrap();
        assert_eq!(result, merged_batch);
    }

    #[test]
    fn test_merge_with_schema() {
        fn test_batch(names: &[&str], types: &[DataType]) -> (Schema, RecordBatch) {
            let fields: Fields = names
                .iter()
                .zip(types)
                .map(|(name, ty)| Field::new(name.to_string(), ty.clone(), false))
                .collect();
            let schema = Schema::new(vec![Field::new(
                "struct",
                DataType::Struct(fields.clone()),
                false,
            )]);
            let children = types.iter().map(new_empty_array).collect::<Vec<_>>();
            let batch = RecordBatch::try_new(
                Arc::new(schema.clone()),
                vec![Arc::new(StructArray::new(fields, children, None)) as ArrayRef],
            );
            (schema, batch.unwrap())
        }

        let (_, left_batch) = test_batch(&["a", "b"], &[DataType::Int32, DataType::Int64]);
        let (_, right_batch) = test_batch(&["c", "b"], &[DataType::Int32, DataType::Int64]);
        let (output_schema, _) = test_batch(
            &["b", "a", "c"],
            &[DataType::Int64, DataType::Int32, DataType::Int32],
        );

        // merge_with_schema follows the field order of the given schema.
        let merged = left_batch
            .merge_with_schema(&right_batch, &output_schema)
            .unwrap();
        assert_eq!(merged.schema().as_ref(), &output_schema);

        // A plain merge appends the right-only fields after the left fields.
        let (naive_schema, _) = test_batch(
            &["a", "b", "c"],
            &[DataType::Int32, DataType::Int64, DataType::Int32],
        );
        let merged = left_batch.merge(&right_batch).unwrap();
        assert_eq!(merged.schema().as_ref(), &naive_schema);
    }

    #[test]
    fn test_merge_list_struct() {
        let x_field = Arc::new(Field::new("x", DataType::Int32, true));
        let y_field = Arc::new(Field::new("y", DataType::Int32, true));
        let x_struct_field = Arc::new(Field::new(
            "item",
            DataType::Struct(Fields::from(vec![x_field.clone()])),
            true,
        ));
        let y_struct_field = Arc::new(Field::new(
            "item",
            DataType::Struct(Fields::from(vec![y_field.clone()])),
            true,
        ));
        let both_struct_field = Arc::new(Field::new(
            "item",
            DataType::Struct(Fields::from(vec![x_field.clone(), y_field.clone()])),
            true,
        ));
        let left_schema = Schema::new(vec![Field::new(
            "list_struct",
            DataType::List(x_struct_field.clone()),
            true,
        )]);
        let right_schema = Schema::new(vec![Field::new(
            "list_struct",
            DataType::List(y_struct_field.clone()),
            true,
        )]);
        let both_schema = Schema::new(vec![Field::new(
            "list_struct",
            DataType::List(both_struct_field.clone()),
            true,
        )]);

        let x = Arc::new(Int32Array::from(vec![1]));
        let y = Arc::new(Int32Array::from(vec![2]));
        let x_struct = Arc::new(StructArray::new(
            Fields::from(vec![x_field.clone()]),
            vec![x.clone()],
            None,
        ));
        let y_struct = Arc::new(StructArray::new(
            Fields::from(vec![y_field.clone()]),
            vec![y.clone()],
            None,
        ));
        let both_struct = Arc::new(StructArray::new(
            Fields::from(vec![x_field.clone(), y_field.clone()]),
            vec![x.clone(), y],
            None,
        ));
        let both_null_struct = Arc::new(StructArray::new(
            Fields::from(vec![x_field, y_field]),
            vec![x, Arc::new(new_null_array(&DataType::Int32, 1))],
            None,
        ));
        let offsets = OffsetBuffer::from_lengths([1]);
        let x_s_list = ListArray::new(x_struct_field, offsets.clone(), x_struct, None);
        let y_s_list = ListArray::new(y_struct_field, offsets.clone(), y_struct, None);
        let both_list = ListArray::new(
            both_struct_field.clone(),
            offsets.clone(),
            both_struct,
            None,
        );
        let both_null_list = ListArray::new(both_struct_field, offsets, both_null_struct, None);
        let x_batch =
            RecordBatch::try_new(Arc::new(left_schema), vec![Arc::new(x_s_list)]).unwrap();
        let y_batch = RecordBatch::try_new(
            Arc::new(right_schema.clone()),
            vec![Arc::new(y_s_list.clone())],
        )
        .unwrap();
        let merged = x_batch.merge(&y_batch).unwrap();
        let expected =
            RecordBatch::try_new(Arc::new(both_schema.clone()), vec![Arc::new(both_list)]).unwrap();
        assert_eq!(merged, expected);

        let y_null_list = new_null_array(y_s_list.data_type(), 1);
        let y_null_batch =
            RecordBatch::try_new(Arc::new(right_schema), vec![Arc::new(y_null_list.clone())])
                .unwrap();
        let expected =
            RecordBatch::try_new(Arc::new(both_schema), vec![Arc::new(both_null_list)]).unwrap();
        let merged = x_batch.merge(&y_null_batch).unwrap();
        assert_eq!(merged, expected);
    }

    #[test]
    fn test_take_record_batch() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Utf8, true),
        ]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int32Array::from_iter_values(0..20)),
                Arc::new(StringArray::from_iter_values(
                    (0..20).map(|i| format!("str-{}", i)),
                )),
            ],
        )
        .unwrap();
        let taken = batch.take(&(vec![1_u32, 5_u32, 10_u32].into())).unwrap();
        assert_eq!(
            taken,
            RecordBatch::try_new(
                schema,
                vec![
                    Arc::new(Int32Array::from(vec![1, 5, 10])),
                    Arc::new(StringArray::from(vec!["str-1", "str-5", "str-10"])),
                ],
            )
            .unwrap()
        )
    }

    #[test]
    fn test_schema_project_by_schema() {
        let metadata = [("key".to_string(), "value".to_string())];
        let schema = Arc::new(
            Schema::new(vec![
                Field::new("a", DataType::Int32, true),
                Field::new("b", DataType::Utf8, true),
            ])
            .with_metadata(metadata.clone().into()),
        );
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int32Array::from_iter_values(0..20)),
                Arc::new(StringArray::from_iter_values(
                    (0..20).map(|i| format!("str-{}", i)),
                )),
            ],
        )
        .unwrap();

        // Empty projection: no columns are kept, but the metadata is carried over.
        let empty_schema = Schema::empty();
        let empty_projected = batch.project_by_schema(&empty_schema).unwrap();
        let expected_schema = empty_schema.with_metadata(metadata.clone().into());
        assert_eq!(
            empty_projected,
            RecordBatch::from(StructArray::new_empty_fields(batch.num_rows(), None))
                .with_schema(Arc::new(expected_schema))
                .unwrap()
        );

        // Reordered projection: columns are returned in the requested order.
        let reordered_schema = Schema::new(vec![
            Field::new("b", DataType::Utf8, true),
            Field::new("a", DataType::Int32, true),
        ]);
        let reordered_projected = batch.project_by_schema(&reordered_schema).unwrap();
        let expected_schema = Arc::new(reordered_schema.with_metadata(metadata.clone().into()));
        assert_eq!(
            reordered_projected,
            RecordBatch::try_new(
                expected_schema,
                vec![
                    Arc::new(StringArray::from_iter_values(
                        (0..20).map(|i| format!("str-{}", i)),
                    )),
                    Arc::new(Int32Array::from_iter_values(0..20)),
                ],
            )
            .unwrap()
        );

        // Sub-schema projection: only the requested subset of columns is selected.
        let sub_schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
        let sub_projected = batch.project_by_schema(&sub_schema).unwrap();
        let expected_schema = Arc::new(sub_schema.with_metadata(metadata.into()));
        assert_eq!(
            sub_projected,
            RecordBatch::try_new(
                expected_schema,
                vec![Arc::new(Int32Array::from_iter_values(0..20))],
            )
            .unwrap()
        );
    }
}