Skip to main content

isaac_sim_arrow/
imu.rs

1// SPDX-License-Identifier: MPL-2.0
2//! Arrow encoder and decoder for the IMU sensor channel.
3use std::sync::{Arc, OnceLock};
4
5use arrow::array::{Array, ArrayRef, Float64Array, Int64Array, StringArray, StructArray};
6use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
7use arrow::record_batch::RecordBatch;
8
/// Borrowed view of a single IMU sample, used as input to [`to_record_batch`].
///
/// Each field is written to the column of the same name in [`schema`]. This
/// type does not enforce units or conventions. The examples use
/// `lin_acc_z = 9.81` and an orientation of `(w, x, y, z) = (1, 0, 0, 0)`,
/// which suggests m/s² and an identity quaternion. NOTE(review): confirm the
/// units and quaternion convention with the producer.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Imu<'a> {
    /// Frame identifier, written to the `frame_id` (Utf8) column.
    pub frame_id: &'a str,
    /// Linear acceleration, x component (`lin_acc_x`, Float64).
    pub lin_acc_x: f64,
    /// Linear acceleration, y component (`lin_acc_y`, Float64).
    pub lin_acc_y: f64,
    /// Linear acceleration, z component (`lin_acc_z`, Float64).
    pub lin_acc_z: f64,
    /// Angular velocity, x component (`ang_vel_x`, Float64).
    pub ang_vel_x: f64,
    /// Angular velocity, y component (`ang_vel_y`, Float64).
    pub ang_vel_y: f64,
    /// Angular velocity, z component (`ang_vel_z`, Float64).
    pub ang_vel_z: f64,
    /// Orientation, w component (`orientation_w`, Float64).
    pub orientation_w: f64,
    /// Orientation, x component (`orientation_x`, Float64).
    pub orientation_x: f64,
    /// Orientation, y component (`orientation_y`, Float64).
    pub orientation_y: f64,
    /// Orientation, z component (`orientation_z`, Float64).
    pub orientation_z: f64,
    /// Sample timestamp (`timestamp_ns`, Int64). The `_ns` suffix suggests
    /// nanoseconds; the reference clock is not specified here.
    pub timestamp_ns: i64,
}
25
/// Owned variant returned by [`from_struct_array`].
///
/// It has the same fields as the borrowed `Imu` input type, except that
/// `frame_id` is an owned `String`, so the value can outlive the Arrow array
/// it was decoded from.
#[derive(Debug, Clone, PartialEq)]
pub struct ImuOwned {
    /// Frame identifier, read from the `frame_id` (Utf8) column.
    pub frame_id: String,
    /// Linear acceleration, x component (`lin_acc_x`, Float64).
    pub lin_acc_x: f64,
    /// Linear acceleration, y component (`lin_acc_y`, Float64).
    pub lin_acc_y: f64,
    /// Linear acceleration, z component (`lin_acc_z`, Float64).
    pub lin_acc_z: f64,
    /// Angular velocity, x component (`ang_vel_x`, Float64).
    pub ang_vel_x: f64,
    /// Angular velocity, y component (`ang_vel_y`, Float64).
    pub ang_vel_y: f64,
    /// Angular velocity, z component (`ang_vel_z`, Float64).
    pub ang_vel_z: f64,
    /// Orientation, w component (`orientation_w`, Float64).
    pub orientation_w: f64,
    /// Orientation, x component (`orientation_x`, Float64).
    pub orientation_x: f64,
    /// Orientation, y component (`orientation_y`, Float64).
    pub orientation_y: f64,
    /// Orientation, z component (`orientation_z`, Float64).
    pub orientation_z: f64,
    /// Sample timestamp (`timestamp_ns`, Int64). The `_ns` suffix suggests
    /// nanoseconds; the reference clock is not specified here.
    pub timestamp_ns: i64,
}
43
44/// Stable Arrow schema for an `Imu` record batch.
45pub fn schema() -> SchemaRef {
46    static SCHEMA: OnceLock<SchemaRef> = OnceLock::new();
47    SCHEMA
48        .get_or_init(|| {
49            Arc::new(Schema::new(vec![
50                Field::new("frame_id", DataType::Utf8, false),
51                Field::new("lin_acc_x", DataType::Float64, false),
52                Field::new("lin_acc_y", DataType::Float64, false),
53                Field::new("lin_acc_z", DataType::Float64, false),
54                Field::new("ang_vel_x", DataType::Float64, false),
55                Field::new("ang_vel_y", DataType::Float64, false),
56                Field::new("ang_vel_z", DataType::Float64, false),
57                Field::new("orientation_w", DataType::Float64, false),
58                Field::new("orientation_x", DataType::Float64, false),
59                Field::new("orientation_y", DataType::Float64, false),
60                Field::new("orientation_z", DataType::Float64, false),
61                Field::new("timestamp_ns", DataType::Int64, false),
62            ]))
63        })
64        .clone()
65}
66
67/// Encode an `Imu` sample as a single-row `RecordBatch` matching [`schema`].
68///
69/// # Example
70///
71/// ```
72/// use isaac_sim_arrow::imu::{Imu, to_record_batch};
73/// let sample = Imu {
74///     frame_id: "imu",
75///     lin_acc_x: 0.0, lin_acc_y: 0.0, lin_acc_z: 9.81,
76///     ang_vel_x: 0.0, ang_vel_y: 0.0, ang_vel_z: 0.0,
77///     orientation_w: 1.0, orientation_x: 0.0, orientation_y: 0.0, orientation_z: 0.0,
78///     timestamp_ns: 1_000_000,
79/// };
80/// let batch = to_record_batch(&sample).unwrap();
81/// assert_eq!(batch.num_rows(), 1);
82/// assert_eq!(batch.num_columns(), 12);
83/// ```
84pub fn to_record_batch(imu: &Imu) -> Result<RecordBatch, arrow::error::ArrowError> {
85    let columns: Vec<ArrayRef> = vec![
86        Arc::new(StringArray::from(vec![imu.frame_id])),
87        Arc::new(Float64Array::from_iter_values(std::iter::once(
88            imu.lin_acc_x,
89        ))),
90        Arc::new(Float64Array::from_iter_values(std::iter::once(
91            imu.lin_acc_y,
92        ))),
93        Arc::new(Float64Array::from_iter_values(std::iter::once(
94            imu.lin_acc_z,
95        ))),
96        Arc::new(Float64Array::from_iter_values(std::iter::once(
97            imu.ang_vel_x,
98        ))),
99        Arc::new(Float64Array::from_iter_values(std::iter::once(
100            imu.ang_vel_y,
101        ))),
102        Arc::new(Float64Array::from_iter_values(std::iter::once(
103            imu.ang_vel_z,
104        ))),
105        Arc::new(Float64Array::from_iter_values(std::iter::once(
106            imu.orientation_w,
107        ))),
108        Arc::new(Float64Array::from_iter_values(std::iter::once(
109            imu.orientation_x,
110        ))),
111        Arc::new(Float64Array::from_iter_values(std::iter::once(
112            imu.orientation_y,
113        ))),
114        Arc::new(Float64Array::from_iter_values(std::iter::once(
115            imu.orientation_z,
116        ))),
117        Arc::new(Int64Array::from_iter_values(std::iter::once(
118            imu.timestamp_ns,
119        ))),
120    ];
121    RecordBatch::try_new(schema(), columns)
122}
123
124/// Decode the first row of a `StructArray` into a heap-owned `ImuOwned`.
125///
126/// # Example
127///
128/// ```
129/// use arrow::array::StructArray;
130/// use isaac_sim_arrow::imu::{Imu, to_record_batch, from_struct_array};
131/// let sample = Imu {
132///     frame_id: "imu",
133///     lin_acc_x: 0.0, lin_acc_y: 0.0, lin_acc_z: 9.81,
134///     ang_vel_x: 0.0, ang_vel_y: 0.0, ang_vel_z: 0.0,
135///     orientation_w: 1.0, orientation_x: 0.0, orientation_y: 0.0, orientation_z: 0.0,
136///     timestamp_ns: 42,
137/// };
138/// let batch = to_record_batch(&sample).unwrap();
139/// let array = StructArray::from(batch);
140/// let owned = from_struct_array(&array).unwrap();
141/// assert_eq!(owned.frame_id, "imu");
142/// assert_eq!(owned.timestamp_ns, 42);
143/// ```
144pub fn from_struct_array(array: &StructArray) -> Result<ImuOwned, arrow::error::ArrowError> {
145    if array.is_empty() {
146        return Err(arrow::error::ArrowError::InvalidArgumentError(
147            "imu struct array is empty".into(),
148        ));
149    }
150    let f64_at = |idx: usize, name: &str| -> Result<f64, arrow::error::ArrowError> {
151        array
152            .column(idx)
153            .as_any()
154            .downcast_ref::<Float64Array>()
155            .ok_or_else(|| {
156                arrow::error::ArrowError::SchemaError(format!("imu '{name}' not Float64"))
157            })
158            .map(|a| a.value(0))
159    };
160    let frame_id = array
161        .column(0)
162        .as_any()
163        .downcast_ref::<StringArray>()
164        .ok_or_else(|| arrow::error::ArrowError::SchemaError("imu 'frame_id' not Utf8".into()))?
165        .value(0)
166        .to_string();
167    let timestamp_ns = array
168        .column(11)
169        .as_any()
170        .downcast_ref::<Int64Array>()
171        .ok_or_else(|| {
172            arrow::error::ArrowError::SchemaError("imu 'timestamp_ns' not Int64".into())
173        })?
174        .value(0);
175    Ok(ImuOwned {
176        frame_id,
177        lin_acc_x: f64_at(1, "lin_acc_x")?,
178        lin_acc_y: f64_at(2, "lin_acc_y")?,
179        lin_acc_z: f64_at(3, "lin_acc_z")?,
180        ang_vel_x: f64_at(4, "ang_vel_x")?,
181        ang_vel_y: f64_at(5, "ang_vel_y")?,
182        ang_vel_z: f64_at(6, "ang_vel_z")?,
183        orientation_w: f64_at(7, "orientation_w")?,
184        orientation_x: f64_at(8, "orientation_x")?,
185        orientation_y: f64_at(9, "orientation_y")?,
186        orientation_z: f64_at(10, "orientation_z")?,
187        timestamp_ns,
188    })
189}
190
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::Array;

    /// A sample with a distinct value in every field, so that a mixed-up column
    /// shows up as a mismatch instead of passing by coincidence. The values
    /// are arbitrary and not physically meaningful (the orientation is not a
    /// unit quaternion).
    fn sample() -> Imu<'static> {
        Imu {
            frame_id: "sim_imu",
            lin_acc_x: 0.1,
            lin_acc_y: 0.2,
            lin_acc_z: 9.81,
            ang_vel_x: 0.3,
            ang_vel_y: 0.4,
            ang_vel_z: 0.5,
            orientation_w: 0.6,
            orientation_x: 0.7,
            orientation_y: 0.8,
            orientation_z: 0.9,
            timestamp_ns: 12345,
        }
    }

    #[test]
    fn round_trips_through_record_batch() {
        let batch = to_record_batch(&sample()).expect("convert");
        assert_eq!(batch.num_rows(), 1);
        assert_eq!(batch.num_columns(), 12);
        assert_eq!(batch.schema(), schema());

        let frame = batch
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .expect("frame_id is Utf8");
        assert_eq!(frame.value(0), "sim_imu");

        let lin_z = batch
            .column(3)
            .as_any()
            .downcast_ref::<Float64Array>()
            .expect("lin_acc_z is Float64");
        assert!((lin_z.value(0) - 9.81).abs() < 1e-9);

        let ts = batch
            .column(11)
            .as_any()
            .downcast_ref::<Int64Array>()
            .expect("timestamp_ns is Int64");
        assert_eq!(ts.value(0), 12345);
    }

    #[test]
    fn from_struct_array_round_trips() {
        let batch = to_record_batch(&sample()).expect("to");
        let array = StructArray::from(batch);
        let owned = from_struct_array(&array).expect("from");
        // Values pass through Arrow unchanged, so exact equality is expected.
        let expected = ImuOwned {
            frame_id: "sim_imu".to_string(),
            lin_acc_x: 0.1,
            lin_acc_y: 0.2,
            lin_acc_z: 9.81,
            ang_vel_x: 0.3,
            ang_vel_y: 0.4,
            ang_vel_z: 0.5,
            orientation_w: 0.6,
            orientation_x: 0.7,
            orientation_y: 0.8,
            orientation_z: 0.9,
            timestamp_ns: 12345,
        };
        assert_eq!(owned, expected);
    }

    #[test]
    fn from_struct_array_rejects_empty_array() {
        let batch = to_record_batch(&sample()).expect("to");
        let empty = StructArray::from(batch.slice(0, 0));
        assert!(from_struct_array(&empty).is_err());
    }
}