Skip to main content

isaac_sim_arrow/
cmd_vel.rs

1// SPDX-License-Identifier: MPL-2.0
2//! Arrow encoder and decoder for the cmd_vel (Twist) actuation channel.
3use std::sync::{Arc, OnceLock};
4
5use arrow::array::{Array, ArrayRef, Float32Array, Int64Array, StructArray};
6use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
7use arrow::record_batch::RecordBatch;
8
/// A single Twist command: three-axis linear and angular velocities plus a nanosecond timestamp.
///
/// [`Default`] yields an all-zero command: every velocity component is `0.0` and
/// `timestamp_ns` is `0`. This makes struct-update syntax convenient:
/// `CmdVel { linear_x: 0.5, ..CmdVel::default() }`.
#[derive(Debug, Clone, Copy, Default, PartialEq)]
pub struct CmdVel {
    /// Linear velocity component along the x axis.
    pub linear_x: f32,
    /// Linear velocity component along the y axis.
    pub linear_y: f32,
    /// Linear velocity component along the z axis.
    pub linear_z: f32,
    /// Angular velocity component about the x axis.
    pub angular_x: f32,
    /// Angular velocity component about the y axis.
    pub angular_y: f32,
    /// Angular velocity component about the z axis.
    pub angular_z: f32,
    /// Timestamp of the command, in nanoseconds.
    pub timestamp_ns: i64,
}
35
36/// Stable Arrow schema for a `CmdVel` record batch.
37pub fn schema() -> SchemaRef {
38    static SCHEMA: OnceLock<SchemaRef> = OnceLock::new();
39    SCHEMA
40        .get_or_init(|| {
41            Arc::new(Schema::new(vec![
42                Field::new("linear_x", DataType::Float32, false),
43                Field::new("linear_y", DataType::Float32, false),
44                Field::new("linear_z", DataType::Float32, false),
45                Field::new("angular_x", DataType::Float32, false),
46                Field::new("angular_y", DataType::Float32, false),
47                Field::new("angular_z", DataType::Float32, false),
48                Field::new("timestamp_ns", DataType::Int64, false),
49            ]))
50        })
51        .clone()
52}
53
54/// Encode a `CmdVel` sample as a single-row `RecordBatch` matching [`schema`].
55///
56/// # Example
57///
58/// ```
59/// use isaac_sim_arrow::cmd_vel::{CmdVel, to_record_batch};
60/// let twist = CmdVel { linear_x: 0.5, angular_z: 0.2, ..CmdVel::default() };
61/// let batch = to_record_batch(&twist).unwrap();
62/// assert_eq!(batch.num_rows(), 1);
63/// assert_eq!(batch.num_columns(), 7);
64/// ```
65pub fn to_record_batch(twist: &CmdVel) -> Result<RecordBatch, arrow::error::ArrowError> {
66    let columns: Vec<ArrayRef> = vec![
67        Arc::new(Float32Array::from_iter_values(std::iter::once(
68            twist.linear_x,
69        ))),
70        Arc::new(Float32Array::from_iter_values(std::iter::once(
71            twist.linear_y,
72        ))),
73        Arc::new(Float32Array::from_iter_values(std::iter::once(
74            twist.linear_z,
75        ))),
76        Arc::new(Float32Array::from_iter_values(std::iter::once(
77            twist.angular_x,
78        ))),
79        Arc::new(Float32Array::from_iter_values(std::iter::once(
80            twist.angular_y,
81        ))),
82        Arc::new(Float32Array::from_iter_values(std::iter::once(
83            twist.angular_z,
84        ))),
85        Arc::new(Int64Array::from_iter_values(std::iter::once(
86            twist.timestamp_ns,
87        ))),
88    ];
89    RecordBatch::try_new(schema(), columns)
90}
91
92/// Decode a single CmdVel sample from a `StructArray` whose fields
93/// match [`schema`]. Returns the first row; errors on field mismatch
94/// or empty input. Symmetric to [`to_record_batch`].
95///
96/// # Example
97///
98/// ```
99/// use arrow::array::StructArray;
100/// use isaac_sim_arrow::cmd_vel::{CmdVel, to_record_batch, from_struct_array};
101/// let twist = CmdVel { linear_x: 1.0, angular_z: -0.5, ..CmdVel::default() };
102/// let batch = to_record_batch(&twist).unwrap();
103/// let array = StructArray::from(batch);
104/// let decoded = from_struct_array(&array).unwrap();
105/// assert_eq!(decoded, twist);
106/// ```
107pub fn from_struct_array(array: &StructArray) -> Result<CmdVel, arrow::error::ArrowError> {
108    if array.is_empty() {
109        return Err(arrow::error::ArrowError::InvalidArgumentError(
110            "cmd_vel struct array is empty".into(),
111        ));
112    }
113    let schema = schema();
114    let names = schema.fields().iter().map(|f| f.name().clone());
115    let mut out = CmdVel::default();
116    for (idx, name) in names.enumerate() {
117        let col = array.column(idx);
118        match name.as_str() {
119            "linear_x" => out.linear_x = col_f32(col, "linear_x")?,
120            "linear_y" => out.linear_y = col_f32(col, "linear_y")?,
121            "linear_z" => out.linear_z = col_f32(col, "linear_z")?,
122            "angular_x" => out.angular_x = col_f32(col, "angular_x")?,
123            "angular_y" => out.angular_y = col_f32(col, "angular_y")?,
124            "angular_z" => out.angular_z = col_f32(col, "angular_z")?,
125            "timestamp_ns" => out.timestamp_ns = col_i64(col, "timestamp_ns")?,
126            other => {
127                return Err(arrow::error::ArrowError::SchemaError(format!(
128                    "unexpected cmd_vel column '{other}'"
129                )));
130            }
131        }
132    }
133    Ok(out)
134}
135
136fn col_f32(col: &ArrayRef, name: &str) -> Result<f32, arrow::error::ArrowError> {
137    col.as_any()
138        .downcast_ref::<Float32Array>()
139        .ok_or_else(|| {
140            arrow::error::ArrowError::SchemaError(format!("cmd_vel '{name}' not Float32"))
141        })
142        .map(|a| a.value(0))
143}
144
145fn col_i64(col: &ArrayRef, name: &str) -> Result<i64, arrow::error::ArrowError> {
146    col.as_any()
147        .downcast_ref::<Int64Array>()
148        .ok_or_else(|| arrow::error::ArrowError::SchemaError(format!("cmd_vel '{name}' not Int64")))
149        .map(|a| a.value(0))
150}
151
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn round_trips_through_struct_array() {
        let twist = CmdVel {
            linear_x: 0.4,
            linear_y: 0.0,
            linear_z: 0.0,
            angular_x: 0.0,
            angular_y: 0.0,
            angular_z: 0.3,
            timestamp_ns: 999,
        };
        let batch = to_record_batch(&twist).expect("convert");
        let array = StructArray::from(batch);
        let decoded = from_struct_array(&array).expect("decode");
        assert_eq!(decoded, twist);
    }

    /// Every field carries a distinct non-zero value so a mis-assigned column is caught.
    #[test]
    fn round_trips_every_field() {
        let twist = CmdVel {
            linear_x: 1.5,
            linear_y: -2.25,
            linear_z: 0.125,
            angular_x: -0.5,
            angular_y: 3.0,
            angular_z: -4.75,
            timestamp_ns: 1_700_000_000_123_456_789,
        };
        let batch = to_record_batch(&twist).expect("convert");
        let array = StructArray::from(batch);
        let decoded = from_struct_array(&array).expect("decode");
        assert_eq!(decoded, twist);
    }

    #[test]
    fn encoded_batch_matches_schema() {
        let batch = to_record_batch(&CmdVel::default()).expect("convert");
        assert_eq!(batch.schema(), schema());
        assert_eq!(batch.num_rows(), 1);
    }

    #[test]
    fn rejects_empty_struct_array() {
        let batch = to_record_batch(&CmdVel::default())
            .expect("convert")
            .slice(0, 0);
        let array = StructArray::from(batch);
        assert!(from_struct_array(&array).is_err());
    }
}