1use crate::builder::buffer_builder::{Int8BufferBuilder, Int32BufferBuilder};
19use crate::builder::{ArrayBuilder, BufferBuilder};
20use crate::{ArrayRef, ArrowPrimitiveType, UnionArray, make_array};
21use arrow_buffer::NullBufferBuilder;
22use arrow_buffer::{ArrowNativeType, Buffer, ScalarBuffer};
23use arrow_data::ArrayDataBuilder;
24use arrow_schema::{ArrowError, DataType, Field};
25use std::any::Any;
26use std::collections::BTreeMap;
27use std::sync::Arc;
28
29#[derive(Debug)]
31struct FieldData {
32    type_id: i8,
34    data_type: DataType,
36    values_buffer: Box<dyn FieldDataValues>,
38    slots: usize,
40    null_buffer_builder: NullBufferBuilder,
42}
43
44trait FieldDataValues: std::fmt::Debug + Send + Sync {
46    fn as_mut_any(&mut self) -> &mut dyn Any;
47
48    fn append_null(&mut self);
49
50    fn finish(&mut self) -> Buffer;
51
52    fn finish_cloned(&self) -> Buffer;
53}
54
55impl<T: ArrowNativeType> FieldDataValues for BufferBuilder<T> {
56    fn as_mut_any(&mut self) -> &mut dyn Any {
57        self
58    }
59
60    fn append_null(&mut self) {
61        self.advance(1)
62    }
63
64    fn finish(&mut self) -> Buffer {
65        self.finish()
66    }
67
68    fn finish_cloned(&self) -> Buffer {
69        Buffer::from_slice_ref(self.as_slice())
70    }
71}
72
73impl FieldData {
74    fn new<T: ArrowPrimitiveType>(type_id: i8, data_type: DataType, capacity: usize) -> Self {
76        Self {
77            type_id,
78            data_type,
79            slots: 0,
80            values_buffer: Box::new(BufferBuilder::<T::Native>::new(capacity)),
81            null_buffer_builder: NullBufferBuilder::new(capacity),
82        }
83    }
84
85    fn append_value<T: ArrowPrimitiveType>(&mut self, v: T::Native) {
87        self.values_buffer
88            .as_mut_any()
89            .downcast_mut::<BufferBuilder<T::Native>>()
90            .expect("Tried to append unexpected type")
91            .append(v);
92
93        self.null_buffer_builder.append(true);
94        self.slots += 1;
95    }
96
97    fn append_null(&mut self) {
99        self.values_buffer.append_null();
100        self.null_buffer_builder.append(false);
101        self.slots += 1;
102    }
103}
104
105#[derive(Debug, Default)]
148pub struct UnionBuilder {
149    len: usize,
151    fields: BTreeMap<String, FieldData>,
153    type_id_builder: Int8BufferBuilder,
155    value_offset_builder: Option<Int32BufferBuilder>,
157    initial_capacity: usize,
158}
159
160impl UnionBuilder {
161    pub fn new_dense() -> Self {
163        Self::with_capacity_dense(1024)
164    }
165
166    pub fn new_sparse() -> Self {
168        Self::with_capacity_sparse(1024)
169    }
170
171    pub fn with_capacity_dense(capacity: usize) -> Self {
173        Self {
174            len: 0,
175            fields: Default::default(),
176            type_id_builder: Int8BufferBuilder::new(capacity),
177            value_offset_builder: Some(Int32BufferBuilder::new(capacity)),
178            initial_capacity: capacity,
179        }
180    }
181
182    pub fn with_capacity_sparse(capacity: usize) -> Self {
184        Self {
185            len: 0,
186            fields: Default::default(),
187            type_id_builder: Int8BufferBuilder::new(capacity),
188            value_offset_builder: None,
189            initial_capacity: capacity,
190        }
191    }
192
193    #[inline]
201    pub fn append_null<T: ArrowPrimitiveType>(
202        &mut self,
203        type_name: &str,
204    ) -> Result<(), ArrowError> {
205        self.append_option::<T>(type_name, None)
206    }
207
208    #[inline]
210    pub fn append<T: ArrowPrimitiveType>(
211        &mut self,
212        type_name: &str,
213        v: T::Native,
214    ) -> Result<(), ArrowError> {
215        self.append_option::<T>(type_name, Some(v))
216    }
217
218    fn append_option<T: ArrowPrimitiveType>(
219        &mut self,
220        type_name: &str,
221        v: Option<T::Native>,
222    ) -> Result<(), ArrowError> {
223        let type_name = type_name.to_string();
224
225        let mut field_data = match self.fields.remove(&type_name) {
226            Some(data) => {
227                if data.data_type != T::DATA_TYPE {
228                    return Err(ArrowError::InvalidArgumentError(format!(
229                        "Attempt to write col \"{}\" with type {} doesn't match existing type {}",
230                        type_name,
231                        T::DATA_TYPE,
232                        data.data_type
233                    )));
234                }
235                data
236            }
237            None => match self.value_offset_builder {
238                Some(_) => FieldData::new::<T>(
239                    self.fields.len() as i8,
240                    T::DATA_TYPE,
241                    self.initial_capacity,
242                ),
243                None => {
245                    let mut fd = FieldData::new::<T>(
246                        self.fields.len() as i8,
247                        T::DATA_TYPE,
248                        self.len.max(self.initial_capacity),
249                    );
250                    for _ in 0..self.len {
251                        fd.append_null();
252                    }
253                    fd
254                }
255            },
256        };
257        self.type_id_builder.append(field_data.type_id);
258
259        match &mut self.value_offset_builder {
260            Some(offset_builder) => {
262                offset_builder.append(field_data.slots as i32);
263            }
264            None => {
266                for (_, fd) in self.fields.iter_mut() {
267                    fd.append_null();
269                }
270            }
271        }
272
273        match v {
274            Some(v) => field_data.append_value::<T>(v),
275            None => field_data.append_null(),
276        }
277
278        self.fields.insert(type_name, field_data);
279        self.len += 1;
280        Ok(())
281    }
282
283    pub fn build(self) -> Result<UnionArray, ArrowError> {
285        let mut children = Vec::with_capacity(self.fields.len());
286        let union_fields = self
287            .fields
288            .into_iter()
289            .map(
290                |(
291                    name,
292                    FieldData {
293                        type_id,
294                        data_type,
295                        mut values_buffer,
296                        slots,
297                        mut null_buffer_builder,
298                    },
299                )| {
300                    let array_ref = make_array(unsafe {
301                        ArrayDataBuilder::new(data_type.clone())
302                            .add_buffer(values_buffer.finish())
303                            .len(slots)
304                            .nulls(null_buffer_builder.finish())
305                            .build_unchecked()
306                    });
307                    children.push(array_ref);
308                    (type_id, Arc::new(Field::new(name, data_type, false)))
309                },
310            )
311            .collect();
312        UnionArray::try_new(
313            union_fields,
314            self.type_id_builder.into(),
315            self.value_offset_builder.map(Into::into),
316            children,
317        )
318    }
319
320    fn build_cloned(&self) -> Result<UnionArray, ArrowError> {
324        let mut children = Vec::with_capacity(self.fields.len());
325        let union_fields: Vec<_> = self
326            .fields
327            .iter()
328            .map(|(name, field_data)| {
329                let FieldData {
330                    type_id,
331                    data_type,
332                    values_buffer,
333                    slots,
334                    null_buffer_builder,
335                } = field_data;
336
337                let array_ref = make_array(unsafe {
338                    ArrayDataBuilder::new(data_type.clone())
339                        .add_buffer(values_buffer.finish_cloned())
340                        .len(*slots)
341                        .nulls(null_buffer_builder.finish_cloned())
342                        .build_unchecked()
343                });
344                children.push(array_ref);
345                (
346                    *type_id,
347                    Arc::new(Field::new(name.clone(), data_type.clone(), false)),
348                )
349            })
350            .collect();
351        UnionArray::try_new(
352            union_fields.into_iter().collect(),
353            ScalarBuffer::from(self.type_id_builder.as_slice().to_vec()),
354            self.value_offset_builder
355                .as_ref()
356                .map(|builder| ScalarBuffer::from(builder.as_slice().to_vec())),
357            children,
358        )
359    }
360}
361
362impl ArrayBuilder for UnionBuilder {
363    fn len(&self) -> usize {
365        self.len
366    }
367
368    fn finish(&mut self) -> ArrayRef {
370        let builder = std::mem::take(self);
372
373        Arc::new(builder.build().unwrap())
375    }
376
377    fn finish_cloned(&self) -> ArrayRef {
379        Arc::new(self.build_cloned().unwrap_or_else(|err| {
382            panic!("UnionBuilder::build_cloned failed unexpectedly: {}", err)
383        }))
384    }
385
386    fn as_any(&self) -> &dyn Any {
388        self
389    }
390
391    fn as_any_mut(&mut self) -> &mut dyn Any {
393        self
394    }
395
396    fn into_box_any(self: Box<Self>) -> Box<dyn Any> {
398        self
399    }
400}
401
#[cfg(test)]
mod tests {
    use super::*;
    use crate::array::Array;
    use crate::cast::AsArray;
    use crate::types::{Float64Type, Int32Type};

    /// Exercises `len`, `finish_cloned` and `finish` on a dense builder.
    #[test]
    fn test_union_builder_array_builder_trait() {
        let mut dense = UnionBuilder::new_dense();

        dense.append::<Int32Type>("a", 1).unwrap();
        dense.append::<Float64Type>("b", 3.0).unwrap();
        dense.append::<Int32Type>("a", 4).unwrap();

        assert_eq!(dense.len(), 3);

        // A snapshot must not disturb the builder.
        let snapshot = dense.finish_cloned();
        assert_eq!(snapshot.len(), 3);

        let snapshot_union = snapshot.as_any().downcast_ref::<UnionArray>().unwrap();
        assert_eq!(snapshot_union.type_ids(), &[0, 1, 0]);
        assert_eq!(snapshot_union.offsets().unwrap().as_ref(), &[0, 0, 1]);
        let snapshot_ints = snapshot_union.child(0).as_primitive::<Int32Type>();
        let snapshot_floats = snapshot_union.child(1).as_primitive::<Float64Type>();
        assert_eq!(snapshot_ints.value(0), 1);
        assert_eq!(snapshot_ints.value(1), 4);
        assert_eq!(snapshot_floats.value(0), 3.0);

        // The builder keeps accepting values after the snapshot.
        dense.append::<Float64Type>("b", 5.0).unwrap();
        assert_eq!(dense.len(), 4);

        let finished = dense.finish();
        assert_eq!(finished.len(), 4);

        let finished_union = finished.as_any().downcast_ref::<UnionArray>().unwrap();
        assert_eq!(finished_union.type_ids(), &[0, 1, 0, 1]);
        assert_eq!(finished_union.offsets().unwrap().as_ref(), &[0, 0, 1, 1]);
        let finished_ints = finished_union.child(0).as_primitive::<Int32Type>();
        let finished_floats = finished_union.child(1).as_primitive::<Float64Type>();
        assert_eq!(finished_ints.value(0), 1);
        assert_eq!(finished_ints.value(1), 4);
        assert_eq!(finished_floats.value(0), 3.0);
        assert_eq!(finished_floats.value(1), 5.0);
    }

    /// Drives a sparse builder through a `Box<dyn ArrayBuilder>`.
    #[test]
    fn test_union_builder_type_erased() {
        let mut builders: Vec<Box<dyn ArrayBuilder>> = vec![Box::new(UnionBuilder::new_sparse())];

        let sparse = builders[0]
            .as_any_mut()
            .downcast_mut::<UnionBuilder>()
            .unwrap();
        sparse.append::<Int32Type>("x", 10).unwrap();
        sparse.append::<Float64Type>("y", 20.0).unwrap();

        assert_eq!(builders[0].len(), 2);

        let arrays = builders
            .into_iter()
            .map(|mut b| b.finish())
            .collect::<Vec<_>>();
        assert_eq!(arrays[0].len(), 2);

        let sparse_union = arrays[0].as_any().downcast_ref::<UnionArray>().unwrap();
        assert_eq!(sparse_union.type_ids(), &[0, 1]);
        assert!(sparse_union.offsets().is_none());

        let ints = sparse_union.child(0).as_primitive::<Int32Type>();
        let floats = sparse_union.child(1).as_primitive::<Float64Type>();
        assert_eq!(ints.value(0), 10);
        // Sparse children are padded with nulls wherever another child is selected.
        assert!(ints.is_null(1));
        assert!(floats.is_null(0));
        assert_eq!(floats.value(1), 20.0);
    }
}