polars_arrow/array/struct_/
mod.rs1use super::{Array, Splitable, new_empty_array, new_null_array};
2use crate::bitmap::Bitmap;
3use crate::datatypes::{ArrowDataType, Field};
4
5mod builder;
6pub use builder::*;
7mod ffi;
8pub(super) mod fmt;
9mod iterator;
10use polars_error::{PolarsResult, polars_bail, polars_ensure};
11
12#[derive(Clone)]
29pub struct StructArray {
30 dtype: ArrowDataType,
31 values: Vec<Box<dyn Array>>,
33 length: usize,
35 validity: Option<Bitmap>,
36}
37
38impl StructArray {
39 pub fn try_new(
49 dtype: ArrowDataType,
50 length: usize,
51 values: Vec<Box<dyn Array>>,
52 validity: Option<Bitmap>,
53 ) -> PolarsResult<Self> {
54 let fields = Self::try_get_fields(&dtype)?;
55
56 polars_ensure!(
57 fields.len() == values.len(),
58 ComputeError:
59 "a StructArray must have a number of fields in its DataType equal to the number of child values"
60 );
61
62 fields
63 .iter().map(|a| &a.dtype)
64 .zip(values.iter().map(|a| a.dtype()))
65 .enumerate()
66 .try_for_each(|(index, (dtype, child))| {
67 if dtype != child {
68 polars_bail!(ComputeError:
69 "The children DataTypes of a StructArray must equal the children data types.
70 However, the field {index} has data type {dtype:?} but the value has data type {child:?}"
71 )
72 } else {
73 Ok(())
74 }
75 })?;
76
77 values
78 .iter()
79 .map(|f| f.len())
80 .enumerate()
81 .try_for_each(|(index, f_length)| {
82 if f_length != length {
83 polars_bail!(ComputeError: "The children must have the given number of values.
84 However, the values at index {index} have a length of {f_length}, which is different from given length {length}.")
85 } else {
86 Ok(())
87 }
88 })?;
89
90 if validity
91 .as_ref()
92 .is_some_and(|validity| validity.len() != length)
93 {
94 polars_bail!(ComputeError:"The validity length of a StructArray must match its number of elements")
95 }
96
97 Ok(Self {
98 dtype,
99 length,
100 values,
101 validity,
102 })
103 }
104
105 pub fn new(
115 dtype: ArrowDataType,
116 length: usize,
117 values: Vec<Box<dyn Array>>,
118 validity: Option<Bitmap>,
119 ) -> Self {
120 Self::try_new(dtype, length, values, validity).unwrap()
121 }
122
123 pub fn new_empty(dtype: ArrowDataType) -> Self {
125 if let ArrowDataType::Struct(fields) = &dtype.to_logical_type() {
126 let values = fields
127 .iter()
128 .map(|field| new_empty_array(field.dtype().clone()))
129 .collect();
130 Self::new(dtype, 0, values, None)
131 } else {
132 panic!("StructArray must be initialized with DataType::Struct");
133 }
134 }
135
136 pub fn new_null(dtype: ArrowDataType, length: usize) -> Self {
138 if let ArrowDataType::Struct(fields) = &dtype {
139 let values = fields
140 .iter()
141 .map(|field| new_null_array(field.dtype().clone(), length))
142 .collect();
143 Self::new(dtype, length, values, Some(Bitmap::new_zeroed(length)))
144 } else {
145 panic!("StructArray must be initialized with DataType::Struct");
146 }
147 }
148}
149
150impl StructArray {
152 #[must_use]
154 pub fn into_data(self) -> (Vec<Field>, usize, Vec<Box<dyn Array>>, Option<Bitmap>) {
155 let Self {
156 dtype,
157 length,
158 values,
159 validity,
160 } = self;
161 let fields = if let ArrowDataType::Struct(fields) = dtype {
162 fields
163 } else {
164 unreachable!()
165 };
166 (fields, length, values, validity)
167 }
168
169 pub fn slice(&mut self, offset: usize, length: usize) {
175 assert!(
176 offset + length <= self.len(),
177 "offset + length may not exceed length of array"
178 );
179 unsafe { self.slice_unchecked(offset, length) }
180 }
181
182 pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) {
189 self.validity = self
190 .validity
191 .take()
192 .map(|bitmap| bitmap.sliced_unchecked(offset, length))
193 .filter(|bitmap| bitmap.unset_bits() > 0);
194 self.values
195 .iter_mut()
196 .for_each(|x| x.slice_unchecked(offset, length));
197 self.length = length;
198 }
199
200 impl_sliced!();
201
202 impl_mut_validity!();
203
204 impl_into_array!();
205}
206
207impl StructArray {
209 #[inline]
210 pub fn len(&self) -> usize {
211 if cfg!(debug_assertions) {
212 for arr in self.values.iter() {
213 assert_eq!(
214 arr.len(),
215 self.length,
216 "StructArray invariant: each array has same length"
217 );
218 }
219 }
220
221 self.length
222 }
223
224 #[inline]
226 pub fn validity(&self) -> Option<&Bitmap> {
227 self.validity.as_ref()
228 }
229
230 pub fn values(&self) -> &[Box<dyn Array>] {
232 &self.values
233 }
234
235 pub fn fields(&self) -> &[Field] {
237 let fields = Self::get_fields(&self.dtype);
238 debug_assert_eq!(self.values().len(), fields.len());
239 fields
240 }
241}
242
243impl StructArray {
244 pub(crate) fn try_get_fields(dtype: &ArrowDataType) -> PolarsResult<&[Field]> {
246 match dtype.to_logical_type() {
247 ArrowDataType::Struct(fields) => Ok(fields),
248 _ => {
249 polars_bail!(ComputeError: "Struct array must be created with a DataType whose physical type is Struct")
250 },
251 }
252 }
253
254 pub fn get_fields(dtype: &ArrowDataType) -> &[Field] {
256 Self::try_get_fields(dtype).unwrap()
257 }
258}
259
260impl Array for StructArray {
261 impl_common_array!();
262
263 fn validity(&self) -> Option<&Bitmap> {
264 self.validity.as_ref()
265 }
266
267 #[inline]
268 fn with_validity(&self, validity: Option<Bitmap>) -> Box<dyn Array> {
269 Box::new(self.clone().with_validity(validity))
270 }
271}
272
273impl Splitable for StructArray {
274 fn check_bound(&self, offset: usize) -> bool {
275 offset <= self.len()
276 }
277
278 unsafe fn _split_at_unchecked(&self, offset: usize) -> (Self, Self) {
279 let (lhs_validity, rhs_validity) = unsafe { self.validity.split_at_unchecked(offset) };
280
281 let mut lhs_values = Vec::with_capacity(self.values.len());
282 let mut rhs_values = Vec::with_capacity(self.values.len());
283
284 for v in self.values.iter() {
285 let (lhs, rhs) = unsafe { v.split_at_boxed_unchecked(offset) };
286 lhs_values.push(lhs);
287 rhs_values.push(rhs);
288 }
289
290 (
291 Self {
292 dtype: self.dtype.clone(),
293 length: offset,
294 values: lhs_values,
295 validity: lhs_validity,
296 },
297 Self {
298 dtype: self.dtype.clone(),
299 length: self.length - offset,
300 values: rhs_values,
301 validity: rhs_validity,
302 },
303 )
304 }
305}