1use crate::array::print_long_array;
19use crate::{make_array, new_null_array, Array, ArrayRef, RecordBatch};
20use arrow_buffer::{BooleanBuffer, Buffer, NullBuffer};
21use arrow_data::{ArrayData, ArrayDataBuilder};
22use arrow_schema::{ArrowError, DataType, Field, FieldRef, Fields};
23use std::sync::Arc;
24use std::{any::Any, ops::Index};
25
26#[derive(Clone)]
77pub struct StructArray {
78 len: usize,
79 data_type: DataType,
80 nulls: Option<NullBuffer>,
81 fields: Vec<ArrayRef>,
82}
83
84impl StructArray {
85 pub fn new(fields: Fields, arrays: Vec<ArrayRef>, nulls: Option<NullBuffer>) -> Self {
91 Self::try_new(fields, arrays, nulls).unwrap()
92 }
93
94 pub fn try_new(
106 fields: Fields,
107 arrays: Vec<ArrayRef>,
108 nulls: Option<NullBuffer>,
109 ) -> Result<Self, ArrowError> {
110 if fields.len() != arrays.len() {
111 return Err(ArrowError::InvalidArgumentError(format!(
112 "Incorrect number of arrays for StructArray fields, expected {} got {}",
113 fields.len(),
114 arrays.len()
115 )));
116 }
117 let len = arrays.first().map(|x| x.len()).unwrap_or_default();
118
119 if let Some(n) = nulls.as_ref() {
120 if n.len() != len {
121 return Err(ArrowError::InvalidArgumentError(format!(
122 "Incorrect number of nulls for StructArray, expected {len} got {}",
123 n.len(),
124 )));
125 }
126 }
127
128 for (f, a) in fields.iter().zip(&arrays) {
129 if f.data_type() != a.data_type() {
130 return Err(ArrowError::InvalidArgumentError(format!(
131 "Incorrect datatype for StructArray field {:?}, expected {} got {}",
132 f.name(),
133 f.data_type(),
134 a.data_type()
135 )));
136 }
137
138 if a.len() != len {
139 return Err(ArrowError::InvalidArgumentError(format!(
140 "Incorrect array length for StructArray field {:?}, expected {} got {}",
141 f.name(),
142 len,
143 a.len()
144 )));
145 }
146
147 if !f.is_nullable() {
148 if let Some(a) = a.logical_nulls() {
149 if !nulls.as_ref().map(|n| n.contains(&a)).unwrap_or_default() {
150 return Err(ArrowError::InvalidArgumentError(format!(
151 "Found unmasked nulls for non-nullable StructArray field {:?}",
152 f.name()
153 )));
154 }
155 }
156 }
157 }
158
159 Ok(Self {
160 len,
161 data_type: DataType::Struct(fields),
162 nulls: nulls.filter(|n| n.null_count() > 0),
163 fields: arrays,
164 })
165 }
166
167 pub fn new_null(fields: Fields, len: usize) -> Self {
169 let arrays = fields
170 .iter()
171 .map(|f| new_null_array(f.data_type(), len))
172 .collect();
173
174 Self {
175 len,
176 data_type: DataType::Struct(fields),
177 nulls: Some(NullBuffer::new_null(len)),
178 fields: arrays,
179 }
180 }
181
182 pub unsafe fn new_unchecked(
188 fields: Fields,
189 arrays: Vec<ArrayRef>,
190 nulls: Option<NullBuffer>,
191 ) -> Self {
192 let len = arrays.first().map(|x| x.len()).unwrap_or_default();
193 Self {
194 len,
195 data_type: DataType::Struct(fields),
196 nulls,
197 fields: arrays,
198 }
199 }
200
201 pub fn new_empty_fields(len: usize, nulls: Option<NullBuffer>) -> Self {
207 if let Some(n) = &nulls {
208 assert_eq!(len, n.len())
209 }
210 Self {
211 len,
212 data_type: DataType::Struct(Fields::empty()),
213 fields: vec![],
214 nulls,
215 }
216 }
217
218 pub fn into_parts(self) -> (Fields, Vec<ArrayRef>, Option<NullBuffer>) {
220 let f = match self.data_type {
221 DataType::Struct(f) => f,
222 _ => unreachable!(),
223 };
224 (f, self.fields, self.nulls)
225 }
226
227 pub fn column(&self, pos: usize) -> &ArrayRef {
229 &self.fields[pos]
230 }
231
232 pub fn num_columns(&self) -> usize {
234 self.fields.len()
235 }
236
237 pub fn columns(&self) -> &[ArrayRef] {
239 &self.fields
240 }
241
242 #[deprecated(note = "Use columns().to_vec()")]
244 pub fn columns_ref(&self) -> Vec<ArrayRef> {
245 self.columns().to_vec()
246 }
247
248 pub fn column_names(&self) -> Vec<&str> {
250 match self.data_type() {
251 DataType::Struct(fields) => fields
252 .iter()
253 .map(|f| f.name().as_str())
254 .collect::<Vec<&str>>(),
255 _ => unreachable!("Struct array's data type is not struct!"),
256 }
257 }
258
259 pub fn fields(&self) -> &Fields {
261 match self.data_type() {
262 DataType::Struct(f) => f,
263 _ => unreachable!(),
264 }
265 }
266
267 pub fn column_by_name(&self, column_name: &str) -> Option<&ArrayRef> {
273 self.column_names()
274 .iter()
275 .position(|c| c == &column_name)
276 .map(|pos| self.column(pos))
277 }
278
279 pub fn slice(&self, offset: usize, len: usize) -> Self {
281 assert!(
282 offset.saturating_add(len) <= self.len,
283 "the length + offset of the sliced StructArray cannot exceed the existing length"
284 );
285
286 let fields = self.fields.iter().map(|a| a.slice(offset, len)).collect();
287
288 Self {
289 len,
290 data_type: self.data_type.clone(),
291 nulls: self.nulls.as_ref().map(|n| n.slice(offset, len)),
292 fields,
293 }
294 }
295}
296
297impl From<ArrayData> for StructArray {
298 fn from(data: ArrayData) -> Self {
299 let fields = data
300 .child_data()
301 .iter()
302 .map(|cd| make_array(cd.clone()))
303 .collect();
304
305 Self {
306 len: data.len(),
307 data_type: data.data_type().clone(),
308 nulls: data.nulls().cloned(),
309 fields,
310 }
311 }
312}
313
314impl From<StructArray> for ArrayData {
315 fn from(array: StructArray) -> Self {
316 let builder = ArrayDataBuilder::new(array.data_type)
317 .len(array.len)
318 .nulls(array.nulls)
319 .child_data(array.fields.iter().map(|x| x.to_data()).collect());
320
321 unsafe { builder.build_unchecked() }
322 }
323}
324
325impl TryFrom<Vec<(&str, ArrayRef)>> for StructArray {
326 type Error = ArrowError;
327
328 fn try_from(values: Vec<(&str, ArrayRef)>) -> Result<Self, ArrowError> {
330 let (fields, arrays): (Vec<_>, _) = values
331 .into_iter()
332 .map(|(name, array)| {
333 (
334 Field::new(name, array.data_type().clone(), array.is_nullable()),
335 array,
336 )
337 })
338 .unzip();
339
340 StructArray::try_new(fields.into(), arrays, None)
341 }
342}
343
344impl Array for StructArray {
345 fn as_any(&self) -> &dyn Any {
346 self
347 }
348
349 fn to_data(&self) -> ArrayData {
350 self.clone().into()
351 }
352
353 fn into_data(self) -> ArrayData {
354 self.into()
355 }
356
357 fn data_type(&self) -> &DataType {
358 &self.data_type
359 }
360
361 fn slice(&self, offset: usize, length: usize) -> ArrayRef {
362 Arc::new(self.slice(offset, length))
363 }
364
365 fn len(&self) -> usize {
366 self.len
367 }
368
369 fn is_empty(&self) -> bool {
370 self.len == 0
371 }
372
373 fn shrink_to_fit(&mut self) {
374 if let Some(nulls) = &mut self.nulls {
375 nulls.shrink_to_fit();
376 }
377 self.fields.iter_mut().for_each(|n| n.shrink_to_fit());
378 }
379
380 fn offset(&self) -> usize {
381 0
382 }
383
384 fn nulls(&self) -> Option<&NullBuffer> {
385 self.nulls.as_ref()
386 }
387
388 fn logical_null_count(&self) -> usize {
389 self.null_count()
391 }
392
393 fn get_buffer_memory_size(&self) -> usize {
394 let mut size = self.fields.iter().map(|a| a.get_buffer_memory_size()).sum();
395 if let Some(n) = self.nulls.as_ref() {
396 size += n.buffer().capacity();
397 }
398 size
399 }
400
401 fn get_array_memory_size(&self) -> usize {
402 let mut size = self.fields.iter().map(|a| a.get_array_memory_size()).sum();
403 size += std::mem::size_of::<Self>();
404 if let Some(n) = self.nulls.as_ref() {
405 size += n.buffer().capacity();
406 }
407 size
408 }
409}
410
411impl From<Vec<(FieldRef, ArrayRef)>> for StructArray {
412 fn from(v: Vec<(FieldRef, ArrayRef)>) -> Self {
413 let (fields, arrays): (Vec<_>, _) = v.into_iter().unzip();
414 StructArray::new(fields.into(), arrays, None)
415 }
416}
417
418impl std::fmt::Debug for StructArray {
419 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
420 writeln!(f, "StructArray")?;
421 writeln!(f, "-- validity: ")?;
422 writeln!(f, "[")?;
423 print_long_array(self, f, |_array, _index, f| write!(f, "valid"))?;
424 writeln!(f, "]\n[")?;
425 for (child_index, name) in self.column_names().iter().enumerate() {
426 let column = self.column(child_index);
427 writeln!(
428 f,
429 "-- child {}: \"{}\" ({:?})",
430 child_index,
431 name,
432 column.data_type()
433 )?;
434 std::fmt::Debug::fmt(column, f)?;
435 writeln!(f)?;
436 }
437 write!(f, "]")
438 }
439}
440
441impl From<(Vec<(FieldRef, ArrayRef)>, Buffer)> for StructArray {
442 fn from(pair: (Vec<(FieldRef, ArrayRef)>, Buffer)) -> Self {
443 let len = pair.0.first().map(|x| x.1.len()).unwrap_or_default();
444 let (fields, arrays): (Vec<_>, Vec<_>) = pair.0.into_iter().unzip();
445 let nulls = NullBuffer::new(BooleanBuffer::new(pair.1, 0, len));
446 Self::new(fields.into(), arrays, Some(nulls))
447 }
448}
449
450impl From<RecordBatch> for StructArray {
451 fn from(value: RecordBatch) -> Self {
452 Self {
453 len: value.num_rows(),
454 data_type: DataType::Struct(value.schema().fields().clone()),
455 nulls: None,
456 fields: value.columns().to_vec(),
457 }
458 }
459}
460
461impl Index<&str> for StructArray {
462 type Output = ArrayRef;
463
464 fn index(&self, name: &str) -> &Self::Output {
474 self.column_by_name(name).unwrap()
475 }
476}
477
#[cfg(test)]
mod tests {
    use super::*;

    use crate::{BooleanArray, Float32Array, Float64Array, Int32Array, Int64Array, StringArray};
    use arrow_buffer::ToByteSlice;

    /// Builds the two-column struct array (`b`: Boolean, `c`: Int32) shared by
    /// several tests and returns it together with its child arrays.
    fn bool_int_struct() -> (Arc<BooleanArray>, Arc<Int32Array>, StructArray) {
        let boolean = Arc::new(BooleanArray::from(vec![false, false, true, true]));
        let int = Arc::new(Int32Array::from(vec![42, 28, 19, 31]));

        let struct_array = StructArray::from(vec![
            (
                Arc::new(Field::new("b", DataType::Boolean, false)),
                boolean.clone() as ArrayRef,
            ),
            (
                Arc::new(Field::new("c", DataType::Int32, false)),
                int.clone() as ArrayRef,
            ),
        ]);

        (boolean, int, struct_array)
    }

    /// Collects `is_valid` for every index of `array`.
    fn validity(array: &dyn Array) -> Vec<bool> {
        (0..array.len()).map(|i| array.is_valid(i)).collect()
    }

    #[test]
    fn test_struct_array_builder() {
        let bools = BooleanArray::from(vec![false, false, true, true]);
        let ints = Int64Array::from(vec![42, 28, 19, 31]);

        let schema = vec![
            Field::new("a", DataType::Boolean, false),
            Field::new("b", DataType::Int64, false),
        ];
        let data = ArrayData::builder(DataType::Struct(schema.into()))
            .len(4)
            .add_child_data(bools.to_data())
            .add_child_data(ints.to_data())
            .build()
            .unwrap();
        let array = StructArray::from(data);

        assert_eq!(array.column(0).as_ref(), &bools);
        assert_eq!(array.column(1).as_ref(), &ints);
    }

    #[test]
    fn test_struct_array_from() {
        let (boolean, int, struct_array) = bool_int_struct();

        assert_eq!(struct_array.column(0).as_ref(), boolean.as_ref());
        assert_eq!(struct_array.column(1).as_ref(), int.as_ref());
        assert_eq!(4, struct_array.len());
        assert_eq!(0, struct_array.null_count());
        assert_eq!(0, struct_array.offset());
    }

    #[test]
    fn test_struct_array_index_access() {
        let (boolean, int, struct_array) = bool_int_struct();

        assert_eq!(struct_array["b"].as_ref(), boolean.as_ref());
        assert_eq!(struct_array["c"].as_ref(), int.as_ref());
    }

    #[test]
    fn test_struct_array_from_vec() {
        let names: ArrayRef = Arc::new(StringArray::from(vec![
            Some("joe"),
            None,
            None,
            Some("mark"),
        ]));
        let numbers: ArrayRef =
            Arc::new(Int32Array::from(vec![Some(1), Some(2), None, Some(4)]));

        let array = StructArray::try_from(vec![("f1", names), ("f2", numbers)]).unwrap();
        let data = array.into_data();

        assert_eq!(4, data.len());
        assert_eq!(0, data.null_count());

        // Validity bitmap 0b1001: only the first and last strings are valid.
        let expected_names = ArrayData::builder(DataType::Utf8)
            .len(4)
            .null_bit_buffer(Some(Buffer::from(&[9_u8])))
            .add_buffer(Buffer::from([0, 3, 3, 3, 7].to_byte_slice()))
            .add_buffer(Buffer::from(b"joemark"))
            .build()
            .unwrap();

        // Validity bitmap 0b1011: only the third integer is null.
        let expected_numbers = ArrayData::builder(DataType::Int32)
            .len(4)
            .null_bit_buffer(Some(Buffer::from(&[11_u8])))
            .add_buffer(Buffer::from([1, 2, 0, 4].to_byte_slice()))
            .build()
            .unwrap();

        assert_eq!(expected_names, data.child_data()[0]);
        assert_eq!(expected_numbers, data.child_data()[1]);
    }

    #[test]
    fn test_struct_array_from_vec_error() {
        // Three strings against four integers: the second column has the wrong length.
        let names: ArrayRef = Arc::new(StringArray::from(vec![Some("joe"), None, None]));
        let numbers: ArrayRef =
            Arc::new(Int32Array::from(vec![Some(1), Some(2), None, Some(4)]));

        let message = StructArray::try_from(vec![("f1", names), ("f2", numbers)])
            .unwrap_err()
            .to_string();

        assert_eq!(
            message,
            "Invalid argument error: Incorrect array length for StructArray field \"f2\", expected 3 got 4"
        )
    }

    #[test]
    #[should_panic(
        expected = "Incorrect datatype for StructArray field \\\"b\\\", expected Int16 got Boolean"
    )]
    fn test_struct_array_from_mismatched_types_single() {
        let column = (
            Arc::new(Field::new("b", DataType::Int16, false)),
            Arc::new(BooleanArray::from(vec![false, false, true, true])) as ArrayRef,
        );
        drop(StructArray::from(vec![column]));
    }

    #[test]
    #[should_panic(
        expected = "Incorrect datatype for StructArray field \\\"b\\\", expected Int16 got Boolean"
    )]
    fn test_struct_array_from_mismatched_types_multiple() {
        let first = (
            Arc::new(Field::new("b", DataType::Int16, false)),
            Arc::new(BooleanArray::from(vec![false, false, true, true])) as ArrayRef,
        );
        let second = (
            Arc::new(Field::new("c", DataType::Utf8, false)),
            Arc::new(Int32Array::from(vec![42, 28, 19, 31])) as ArrayRef,
        );
        drop(StructArray::from(vec![first, second]));
    }

    #[test]
    fn test_struct_array_slice() {
        let bool_data = ArrayData::builder(DataType::Boolean)
            .len(5)
            .add_buffer(Buffer::from([0b00010000]))
            .null_bit_buffer(Some(Buffer::from([0b00010001])))
            .build()
            .unwrap();
        let int_data = ArrayData::builder(DataType::Int32)
            .len(5)
            .add_buffer(Buffer::from([0, 28, 42, 0, 0].to_byte_slice()))
            .null_bit_buffer(Some(Buffer::from([0b00000110])))
            .build()
            .unwrap();

        let schema = vec![
            Field::new("a", DataType::Boolean, true),
            Field::new("b", DataType::Int32, true),
        ];
        let data = ArrayData::builder(DataType::Struct(schema.into()))
            .len(5)
            .add_child_data(bool_data.clone())
            .add_child_data(int_data.clone())
            .null_bit_buffer(Some(Buffer::from([0b00010111])))
            .build()
            .unwrap();
        let array = StructArray::from(data);

        // Struct-level validity: only index 3 is null.
        assert_eq!(5, array.len());
        assert_eq!(1, array.null_count());
        assert_eq!(vec![true, true, true, false, true], validity(&array));
        assert_eq!(bool_data, array.column(0).to_data());
        assert_eq!(int_data, array.column(1).to_data());

        // First child: valid at 0 (false) and 4 (true).
        let bools = array.column(0);
        let bools = bools.as_any().downcast_ref::<BooleanArray>().unwrap();
        assert_eq!(5, bools.len());
        assert_eq!(3, bools.null_count());
        assert_eq!(vec![true, false, false, false, true], validity(bools));
        assert!(!bools.value(0));
        assert!(bools.value(4));

        // Second child: valid at 1 (28) and 2 (42).
        let ints = array.column(1);
        let ints = ints.as_any().downcast_ref::<Int32Array>().unwrap();
        assert_eq!(5, ints.len());
        assert_eq!(3, ints.null_count());
        assert_eq!(vec![false, true, true, false, false], validity(ints));
        assert_eq!(28, ints.value(1));
        assert_eq!(42, ints.value(2));

        // Slice off the first two values and check everything again.
        let sliced = array.slice(2, 3);
        let sliced = sliced.as_any().downcast_ref::<StructArray>().unwrap();
        assert_eq!(3, sliced.len());
        assert_eq!(1, sliced.null_count());
        assert_eq!(vec![true, false, true], validity(sliced));

        let sliced_bools = sliced.column(0);
        let sliced_bools = sliced_bools
            .as_any()
            .downcast_ref::<BooleanArray>()
            .unwrap();
        assert_eq!(3, sliced_bools.len());
        assert_eq!(vec![false, false, true], validity(sliced_bools));
        assert!(sliced_bools.value(2));

        let sliced_ints = sliced.column(1);
        let sliced_ints = sliced_ints.as_any().downcast_ref::<Int32Array>().unwrap();
        assert_eq!(3, sliced_ints.len());
        assert_eq!(vec![true, false, false], validity(sliced_ints));
        assert_eq!(42, sliced_ints.value(0));
    }

    #[test]
    #[should_panic(
        expected = "Incorrect array length for StructArray field \\\"c\\\", expected 1 got 2"
    )]
    fn test_invalid_struct_child_array_lengths() {
        let short = (
            Arc::new(Field::new("b", DataType::Float32, false)),
            Arc::new(Float32Array::from(vec![1.1])) as ArrayRef,
        );
        let long = (
            Arc::new(Field::new("c", DataType::Float64, false)),
            Arc::new(Float64Array::from(vec![2.2, 3.3])) as ArrayRef,
        );
        drop(StructArray::from(vec![short, long]));
    }

    #[test]
    fn test_struct_array_from_empty() {
        let empty = StructArray::from(vec![]);
        assert!(empty.is_empty())
    }

    #[test]
    #[should_panic(expected = "Found unmasked nulls for non-nullable StructArray field \\\"c\\\"")]
    fn test_struct_array_from_mismatched_nullability() {
        let column = (
            Arc::new(Field::new("c", DataType::Int32, false)),
            Arc::new(Int32Array::from(vec![Some(42), None, Some(19)])) as ArrayRef,
        );
        drop(StructArray::from(vec![column]));
    }

    #[test]
    fn test_struct_array_fmt_debug() {
        // Thirty values where every odd index is null at the struct level.
        let schema: Fields = vec![Arc::new(Field::new("c", DataType::Int32, true))].into();
        let values = vec![Arc::new(Int32Array::from((0..30).collect::<Vec<_>>())) as ArrayRef];
        let mask = NullBuffer::new(BooleanBuffer::from(
            (0..30).map(|i| i % 2 == 0).collect::<Vec<_>>(),
        ));
        let arr: StructArray = StructArray::new(schema, values, Some(mask));

        assert_eq!(format!("{arr:?}"), "StructArray\n-- validity: \n[\n valid,\n null,\n valid,\n null,\n valid,\n null,\n valid,\n null,\n valid,\n null,\n ...10 elements...,\n valid,\n null,\n valid,\n null,\n valid,\n null,\n valid,\n null,\n valid,\n null,\n]\n[\n-- child 0: \"c\" (Int32)\nPrimitiveArray<Int32>\n[\n 0,\n 1,\n 2,\n 3,\n 4,\n 5,\n 6,\n 7,\n 8,\n 9,\n ...10 elements...,\n 20,\n 21,\n 22,\n 23,\n 24,\n 25,\n 26,\n 27,\n 28,\n 29,\n]\n]")
    }
}