1use dora_message::{
2 arrow_data::ArrayData,
3 arrow_schema::DataType,
4 metadata::{ArrowTypeInfo, BufferOffset},
5};
6use eyre::Context;
7
8pub trait ArrowTypeInfoExt {
9 fn empty() -> Self;
10 fn byte_array(data_len: usize) -> Self;
11
12 unsafe fn from_array(
17 array: &ArrayData,
18 region_start: *const u8,
19 region_len: usize,
20 ) -> eyre::Result<Self>
21 where
22 Self: Sized;
23}
24
25impl ArrowTypeInfoExt for ArrowTypeInfo {
26 fn empty() -> Self {
27 Self {
28 data_type: DataType::Null,
29 len: 0,
30 null_count: 0,
31 validity: None,
32 offset: 0,
33 buffer_offsets: Vec::new(),
34 child_data: Vec::new(),
35 }
36 }
37
38 fn byte_array(data_len: usize) -> Self {
39 Self {
40 data_type: DataType::UInt8,
41 len: data_len,
42 null_count: 0,
43 validity: None,
44 offset: 0,
45 buffer_offsets: vec![BufferOffset {
46 offset: 0,
47 len: data_len,
48 }],
49 child_data: Vec::new(),
50 }
51 }
52
53 unsafe fn from_array(
54 array: &ArrayData,
55 region_start: *const u8,
56 region_len: usize,
57 ) -> eyre::Result<Self> {
58 Ok(Self {
59 data_type: array.data_type().clone(),
60 len: array.len(),
61 null_count: array.null_count(),
62 validity: array.nulls().map(|b| b.validity().to_owned()),
63 offset: array.offset(),
64 buffer_offsets: array
65 .buffers()
66 .iter()
67 .map(|b| {
68 let ptr = b.as_ptr();
69 if ptr as usize <= region_start as usize {
70 eyre::bail!("ptr {ptr:p} starts before region {region_start:p}");
71 }
72 if ptr as usize >= region_start as usize + region_len {
73 eyre::bail!("ptr {ptr:p} starts after region {region_start:p}");
74 }
75 if ptr as usize + b.len() > region_start as usize + region_len {
76 eyre::bail!("ptr {ptr:p} ends after region {region_start:p}");
77 }
78 let offset = usize::try_from(unsafe { ptr.offset_from(region_start) })
79 .context("offset_from is negative")?;
80
81 Result::<_, eyre::Report>::Ok(BufferOffset {
82 offset,
83 len: b.len(),
84 })
85 })
86 .collect::<Result<_, _>>()?,
87 child_data: array
88 .child_data()
89 .iter()
90 .map(|c| unsafe { Self::from_array(c, region_start, region_len) })
91 .collect::<Result<_, _>>()?,
92 })
93 }
94}