1use crate::traits::message::ArrowMessage;
2
3pub fn make_union_fields(
4 name: impl Into<String>,
5 fields: Vec<arrow::datatypes::Field>,
6) -> arrow::datatypes::Field {
7 arrow::datatypes::Field::new(
8 name,
9 arrow::datatypes::DataType::Union(
10 arrow::datatypes::UnionFields::new(0..fields.len() as i8, fields),
11 arrow::datatypes::UnionMode::Dense,
12 ),
13 false,
14 )
15}
16
17pub fn unpack_union(
18 data: arrow::array::ArrayData,
19) -> (
20 std::collections::HashMap<String, usize>,
21 Vec<arrow::array::ArrayRef>,
22) {
23 let (fields, _, _, children) = arrow::array::UnionArray::from(data).into_parts();
24
25 let map = fields
26 .iter()
27 .map(|(id, field)| (field.name().into(), id as usize))
28 .collect::<std::collections::HashMap<String, usize>>();
29
30 (map, children)
31}
32
33pub fn extract_union_data<T: ArrowMessage>(
34 field: &str,
35 map: &std::collections::HashMap<String, usize>,
36 children: &[arrow::array::ArrayRef],
37) -> arrow::error::Result<T> {
38 use arrow::array::Array;
39
40 T::try_from_arrow(
41 children
42 .get(
43 *map.get(field)
44 .ok_or(arrow::error::ArrowError::InvalidArgumentError(format!(
45 "Field {} not found",
46 field
47 )))?,
48 )
49 .ok_or(arrow::error::ArrowError::InvalidArgumentError(format!(
50 "Field {} not found",
51 field
52 )))?
53 .into_data(),
54 )
55}
56
57pub fn get_union_fields<T: ArrowMessage>() -> arrow::error::Result<arrow::datatypes::UnionFields> {
58 match T::field("").data_type() {
59 arrow::datatypes::DataType::Union(fields, _) => Ok(fields.clone()),
60 _ => Err(arrow::error::ArrowError::InvalidArgumentError(
61 "Expected Union data type".to_string(),
62 )),
63 }
64}
65
66pub fn make_union_array(
67 union_fields: arrow::datatypes::UnionFields,
68 children: Vec<arrow::array::ArrayRef>,
69) -> Result<arrow::array::ArrayRef, arrow::error::ArrowError> {
70 arrow::array::UnionArray::try_new(
71 union_fields,
72 arrow::buffer::ScalarBuffer::from(vec![]),
73 Some(arrow::buffer::ScalarBuffer::from(vec![])),
74 children,
75 )
76 .map(|union| std::sync::Arc::new(union) as arrow::array::ArrayRef)
77}