1use arrow_array::{cast::AsArray, make_array, Array, StructArray};
7use arrow_buffer::NullBuffer;
8use arrow_data::{ArrayData, ArrayDataBuilder};
9use arrow_schema::ArrowError;
10
11pub trait StructArrayExt {
12 fn normalize_slicing(&self) -> Result<Self, ArrowError>
22 where
23 Self: Sized;
24
25 fn pushdown_nulls(&self) -> Result<Self, ArrowError>
44 where
45 Self: Sized;
46}
47
48fn normalized_struct_array_data(data: ArrayData) -> Result<ArrayData, ArrowError> {
49 let parent_offset = data.offset();
50 let parent_len = data.len();
51 let modified_children = data
52 .child_data()
53 .iter()
54 .map(|d| {
55 let d = normalized_struct_array_data(d.clone())?;
56 let offset = d.offset();
57 let len = d.len();
58 if len < parent_len + parent_offset {
59 return Err(ArrowError::InvalidArgumentError(format!(
60 "Child array {} has length {} which is less than the parent length {} plus the parent offset {}",
61 d.data_type(),
62 len,
63 parent_len,
64 parent_offset
65 )));
66 }
67 let new_offset = offset + parent_offset;
68 d.into_builder().offset(new_offset)
69 .len(parent_len)
70 .build()
71 })
72 .collect::<Result<Vec<_>, _>>()?;
73 ArrayDataBuilder::new(data.data_type().clone())
74 .len(parent_len)
75 .offset(0)
76 .buffers(data.buffers().to_vec())
77 .child_data(modified_children)
78 .build()
79}
80
81impl StructArrayExt for StructArray {
82 fn normalize_slicing(&self) -> Result<Self, ArrowError>
83 where
84 Self: Sized,
85 {
86 if self.offset() == 0 && self.columns().iter().all(|c| c.len() == self.len()) {
87 return Ok(self.clone());
88 }
89
90 let data = normalized_struct_array_data(self.to_data())?;
91 Ok(Self::from(data))
92 }
93
94 fn pushdown_nulls(&self) -> Result<Self, ArrowError>
95 where
96 Self: Sized,
97 {
98 let Some(validity) = self.nulls() else {
99 return Ok(self.clone());
100 };
101 let data = self.to_data();
102 let children = data
103 .child_data()
104 .iter()
105 .map(|c| {
106 if let Some(child_validity) = c.nulls() {
107 let new_validity = child_validity.inner() & validity.inner();
108 c.clone()
109 .into_builder()
110 .nulls(Some(NullBuffer::from(new_validity)))
111 .build()
112 } else {
113 Ok(c.clone()
114 .into_builder()
115 .nulls(Some(validity.clone()))
116 .build()?)
117 }
118 })
119 .collect::<Result<Vec<_>, _>>()?;
120 let arr = make_array(data.into_builder().child_data(children).build()?);
121 Ok(arr.as_struct().clone())
122 }
123}
124
#[cfg(test)]
mod tests {
    use arrow_array::{cast::AsArray, make_array, Array, Int32Array, StructArray};
    use arrow_schema::{DataType, Field, Fields};
    use std::sync::Arc;

    use crate::r#struct::StructArrayExt;

    /// Builds a struct array with two nullable Int32 fields, `x` and `y`, and
    /// no struct-level validity.
    fn xy_struct(x: Vec<i32>, y: Vec<i32>) -> StructArray {
        let fields = Fields::from(vec![
            Field::new("x", DataType::Int32, true),
            Field::new("y", DataType::Int32, true),
        ]);
        StructArray::new(
            fields,
            vec![Arc::new(Int32Array::from(x)), Arc::new(Int32Array::from(y))],
            None,
        )
    }

    /// An unsliced array is returned unchanged.
    #[test]
    fn test_normalize_slicing_no_offset() {
        let original = xy_struct(vec![1, 2, 3], vec![4, 5, 6]);

        let normalized = original.normalize_slicing().unwrap();

        assert_eq!(normalized, original);
    }

    /// An array sliced through `StructArray::slice` is left as it is.
    #[test]
    fn test_arrow_rs_slicing() {
        let original = xy_struct(vec![1, 2, 3, 4], vec![5, 6, 7, 8]);
        let sliced = original.slice(1, 2);

        let normalized = sliced.normalize_slicing().unwrap();

        assert_eq!(normalized, sliced);
    }

    /// An array whose offset/length is set only on the struct-level data
    /// normalizes to the same result as `StructArray::slice`.
    #[test]
    fn test_arrow_cpp_slicing() {
        let original = xy_struct(vec![1, 2, 3, 4], vec![5, 6, 7, 8]);

        // Apply the slice to the parent data only, leaving the children untouched.
        let parent_only_slice = original
            .to_data()
            .into_builder()
            .offset(1)
            .len(2)
            .build()
            .unwrap();
        let sliced = make_array(parent_only_slice);

        let normalized = sliced.as_struct().normalize_slicing().unwrap();

        assert_eq!(normalized, original.slice(1, 2));
    }
}