1use std::sync::{Arc, OnceLock};
4
5use arrow::array::{Array, ArrayRef, Float64Array, Int64Array, StringArray, StructArray};
6use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
7use arrow::record_batch::RecordBatch;
8
/// Borrowed view of a single IMU sample, encoded by [`to_record_batch`].
///
/// Each field corresponds one-to-one (by name) to a column of [`schema`]. `frame_id`
/// borrows from the caller, so constructing an `Imu` does not allocate; see
/// [`ImuOwned`] for the owned counterpart returned by [`from_struct_array`].
///
/// Every field is `Copy`, so the whole struct is `Copy` as well.
// Per-field docs are intentionally omitted: each field mirrors the `schema()` column
// of the same name, which is the single place the layout is defined.
#[allow(missing_docs)]
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Imu<'a> {
    pub frame_id: &'a str,
    pub lin_acc_x: f64,
    pub lin_acc_y: f64,
    pub lin_acc_z: f64,
    pub ang_vel_x: f64,
    pub ang_vel_y: f64,
    pub ang_vel_z: f64,
    pub orientation_w: f64,
    pub orientation_x: f64,
    pub orientation_y: f64,
    pub orientation_z: f64,
    pub timestamp_ns: i64,
}
25
/// Owned counterpart of [`Imu`], produced by [`from_struct_array`].
///
/// `frame_id` is an owned [`String`] and the type has no lifetime parameter, so a
/// decoded value is independent of the Arrow array it was read from.
// Per-field docs are intentionally omitted: each field mirrors the `schema()` column
// of the same name.
#[allow(missing_docs)]
#[derive(Clone, Debug, PartialEq)]
pub struct ImuOwned {
    pub frame_id: String,
    pub lin_acc_x: f64,
    pub lin_acc_y: f64,
    pub lin_acc_z: f64,
    pub ang_vel_x: f64,
    pub ang_vel_y: f64,
    pub ang_vel_z: f64,
    pub orientation_w: f64,
    pub orientation_x: f64,
    pub orientation_y: f64,
    pub orientation_z: f64,
    pub timestamp_ns: i64,
}
43
44pub fn schema() -> SchemaRef {
46 static SCHEMA: OnceLock<SchemaRef> = OnceLock::new();
47 SCHEMA
48 .get_or_init(|| {
49 Arc::new(Schema::new(vec![
50 Field::new("frame_id", DataType::Utf8, false),
51 Field::new("lin_acc_x", DataType::Float64, false),
52 Field::new("lin_acc_y", DataType::Float64, false),
53 Field::new("lin_acc_z", DataType::Float64, false),
54 Field::new("ang_vel_x", DataType::Float64, false),
55 Field::new("ang_vel_y", DataType::Float64, false),
56 Field::new("ang_vel_z", DataType::Float64, false),
57 Field::new("orientation_w", DataType::Float64, false),
58 Field::new("orientation_x", DataType::Float64, false),
59 Field::new("orientation_y", DataType::Float64, false),
60 Field::new("orientation_z", DataType::Float64, false),
61 Field::new("timestamp_ns", DataType::Int64, false),
62 ]))
63 })
64 .clone()
65}
66
67pub fn to_record_batch(imu: &Imu) -> Result<RecordBatch, arrow::error::ArrowError> {
85 let columns: Vec<ArrayRef> = vec![
86 Arc::new(StringArray::from(vec![imu.frame_id])),
87 Arc::new(Float64Array::from_iter_values(std::iter::once(
88 imu.lin_acc_x,
89 ))),
90 Arc::new(Float64Array::from_iter_values(std::iter::once(
91 imu.lin_acc_y,
92 ))),
93 Arc::new(Float64Array::from_iter_values(std::iter::once(
94 imu.lin_acc_z,
95 ))),
96 Arc::new(Float64Array::from_iter_values(std::iter::once(
97 imu.ang_vel_x,
98 ))),
99 Arc::new(Float64Array::from_iter_values(std::iter::once(
100 imu.ang_vel_y,
101 ))),
102 Arc::new(Float64Array::from_iter_values(std::iter::once(
103 imu.ang_vel_z,
104 ))),
105 Arc::new(Float64Array::from_iter_values(std::iter::once(
106 imu.orientation_w,
107 ))),
108 Arc::new(Float64Array::from_iter_values(std::iter::once(
109 imu.orientation_x,
110 ))),
111 Arc::new(Float64Array::from_iter_values(std::iter::once(
112 imu.orientation_y,
113 ))),
114 Arc::new(Float64Array::from_iter_values(std::iter::once(
115 imu.orientation_z,
116 ))),
117 Arc::new(Int64Array::from_iter_values(std::iter::once(
118 imu.timestamp_ns,
119 ))),
120 ];
121 RecordBatch::try_new(schema(), columns)
122}
123
124pub fn from_struct_array(array: &StructArray) -> Result<ImuOwned, arrow::error::ArrowError> {
145 if array.is_empty() {
146 return Err(arrow::error::ArrowError::InvalidArgumentError(
147 "imu struct array is empty".into(),
148 ));
149 }
150 let f64_at = |idx: usize, name: &str| -> Result<f64, arrow::error::ArrowError> {
151 array
152 .column(idx)
153 .as_any()
154 .downcast_ref::<Float64Array>()
155 .ok_or_else(|| {
156 arrow::error::ArrowError::SchemaError(format!("imu '{name}' not Float64"))
157 })
158 .map(|a| a.value(0))
159 };
160 let frame_id = array
161 .column(0)
162 .as_any()
163 .downcast_ref::<StringArray>()
164 .ok_or_else(|| arrow::error::ArrowError::SchemaError("imu 'frame_id' not Utf8".into()))?
165 .value(0)
166 .to_string();
167 let timestamp_ns = array
168 .column(11)
169 .as_any()
170 .downcast_ref::<Int64Array>()
171 .ok_or_else(|| {
172 arrow::error::ArrowError::SchemaError("imu 'timestamp_ns' not Int64".into())
173 })?
174 .value(0);
175 Ok(ImuOwned {
176 frame_id,
177 lin_acc_x: f64_at(1, "lin_acc_x")?,
178 lin_acc_y: f64_at(2, "lin_acc_y")?,
179 lin_acc_z: f64_at(3, "lin_acc_z")?,
180 ang_vel_x: f64_at(4, "ang_vel_x")?,
181 ang_vel_y: f64_at(5, "ang_vel_y")?,
182 ang_vel_z: f64_at(6, "ang_vel_z")?,
183 orientation_w: f64_at(7, "orientation_w")?,
184 orientation_x: f64_at(8, "orientation_x")?,
185 orientation_y: f64_at(9, "orientation_y")?,
186 orientation_z: f64_at(10, "orientation_z")?,
187 timestamp_ns,
188 })
189}
190
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::Array;

    /// Fixed sample shared by both round-trip tests.
    fn sample_imu() -> Imu<'static> {
        Imu {
            frame_id: "sim_imu",
            lin_acc_x: 0.1,
            lin_acc_y: 0.2,
            lin_acc_z: 9.81,
            ang_vel_x: 0.0,
            ang_vel_y: 0.0,
            ang_vel_z: 0.5,
            orientation_w: 1.0,
            orientation_x: 0.0,
            orientation_y: 0.0,
            orientation_z: 0.0,
            timestamp_ns: 12345,
        }
    }

    #[test]
    fn round_trips_through_record_batch() {
        let batch = to_record_batch(&sample_imu()).expect("convert");
        assert_eq!(batch.num_rows(), 1);
        assert_eq!(batch.num_columns(), 12);

        let frame_id = batch.column(0).as_any().downcast_ref::<StringArray>();
        assert_eq!(frame_id.expect("frame_id is Utf8").value(0), "sim_imu");

        let lin_acc_z = batch.column(3).as_any().downcast_ref::<Float64Array>();
        let lin_acc_z = lin_acc_z.expect("lin_acc_z is Float64").value(0);
        assert!((lin_acc_z - 9.81).abs() < 1e-9);

        let timestamp = batch.column(11).as_any().downcast_ref::<Int64Array>();
        assert_eq!(timestamp.expect("timestamp_ns is Int64").value(0), 12345);
    }

    #[test]
    fn from_struct_array_round_trips() {
        let batch = to_record_batch(&sample_imu()).expect("to");
        let owned = from_struct_array(&StructArray::from(batch)).expect("from");
        assert_eq!(owned.frame_id, "sim_imu");
        assert!((owned.lin_acc_z - 9.81).abs() < 1e-9);
        assert_eq!(owned.timestamp_ns, 12345);
    }
}