flarrow_message/
helper.rs1use std::{collections::HashMap, sync::Arc};
2
3use arrow_buffer::ScalarBuffer;
4use arrow_schema::{ArrowError, DataType, Field, UnionFields, UnionMode};
5
6use arrow_array::{Array, ArrayRef, UnionArray};
7use arrow_data::ArrayData;
8
9use crate::prelude::*;
10
11pub fn make_union_fields(name: impl Into<String>, fields: Vec<Field>) -> Field {
12 Field::new(
13 name,
14 DataType::Union(
15 UnionFields::new(0..fields.len() as i8, fields),
16 UnionMode::Dense,
17 ),
18 false,
19 )
20}
21
22pub fn unpack_union(data: ArrayData) -> (HashMap<String, usize>, Vec<ArrayRef>) {
23 let (fields, _, _, children) = UnionArray::from(data).into_parts();
24
25 let map = fields
26 .iter()
27 .map(|(id, field)| (field.name().into(), id as usize))
28 .collect::<HashMap<String, usize>>();
29
30 (map, children)
31}
32
33pub fn extract_union_data<T: ArrowMessage>(
34 field: &str,
35 map: &HashMap<String, usize>,
36 children: &[ArrayRef],
37) -> ArrowResult<T> {
38 T::try_from_arrow(
39 children
40 .get(
41 *map.get(field)
42 .ok_or(ArrowError::InvalidArgumentError(format!(
43 "Field {} not found",
44 field
45 )))?,
46 )
47 .ok_or(ArrowError::InvalidArgumentError(format!(
48 "Field {} not found",
49 field
50 )))?
51 .into_data(),
52 )
53}
54
55pub fn get_union_fields<T: ArrowMessage>() -> ArrowResult<UnionFields> {
56 match T::field("").data_type() {
57 DataType::Union(fields, _) => Ok(fields.clone()),
58 _ => Err(ArrowError::InvalidArgumentError(
59 "Expected Union data type".to_string(),
60 )),
61 }
62}
63
64pub fn make_union_array(
65 union_fields: UnionFields,
66 children: Vec<ArrayRef>,
67) -> ArrowResult<ArrayRef> {
68 UnionArray::try_new(
69 union_fields,
70 ScalarBuffer::from(vec![]),
71 Some(ScalarBuffer::from(vec![])),
72 children,
73 )
74 .map(|union| Arc::new(union) as ArrayRef)
75}