1use crate::traits::ArrowMessage;
2
3pub fn make_union_fields(
4 name: impl Into<String>,
5 fields: Vec<arrow::datatypes::Field>,
6) -> arrow::datatypes::Field {
7 arrow::datatypes::Field::new(
8 name,
9 arrow::datatypes::DataType::Union(
10 arrow::datatypes::UnionFields::new(0..fields.len() as i8, fields),
11 arrow::datatypes::UnionMode::Dense,
12 ),
13 false,
14 )
15}
16
17pub fn unpack_union(
18 data: arrow::array::ArrayData,
19) -> (
20 std::collections::HashMap<String, usize>,
21 Vec<arrow::array::ArrayRef>,
22) {
23 let (fields, _, _, children) = arrow::array::UnionArray::from(data).into_parts();
24
25 let map = fields
26 .iter()
27 .map(|(id, field)| (field.name().into(), id as usize))
28 .collect::<std::collections::HashMap<String, usize>>();
29
30 (map, children)
31}
32
33pub fn extract_union_data<T: ArrowMessage>(
34 field: &str,
35 map: &std::collections::HashMap<String, usize>,
36 children: &[arrow::array::ArrayRef],
37) -> miette::Result<T, arrow::error::ArrowError> {
38 use arrow::array::Array;
39
40 T::try_from_arrow(
41 children
42 .get(
43 *map.get(field)
44 .ok_or(arrow::error::ArrowError::InvalidArgumentError(format!(
45 "Field {} not found",
46 field
47 )))?,
48 )
49 .ok_or(arrow::error::ArrowError::InvalidArgumentError(format!(
50 "Field {} not found",
51 field
52 )))?
53 .into_data(),
54 )
55}
56
57pub fn get_union_fields<T: ArrowMessage>()
58-> miette::Result<arrow::datatypes::UnionFields, arrow::error::ArrowError> {
59 match T::field("").data_type() {
60 arrow::datatypes::DataType::Union(fields, _) => Ok(fields.clone()),
61 _ => Err(arrow::error::ArrowError::InvalidArgumentError(
62 "Expected Union data type".to_string(),
63 )),
64 }
65}
66
67pub fn make_union_array(
68 union_fields: arrow::datatypes::UnionFields,
69 children: Vec<arrow::array::ArrayRef>,
70) -> Result<arrow::array::ArrayRef, arrow::error::ArrowError> {
71 arrow::array::UnionArray::try_new(
72 union_fields,
73 arrow::buffer::ScalarBuffer::from(vec![]),
74 Some(arrow::buffer::ScalarBuffer::from(vec![])),
75 children,
76 )
77 .map(|union| std::sync::Arc::new(union) as arrow::array::ArrayRef)
78}