1use std::sync::{Arc, OnceLock};
4
5use arrow::array::{Array, ArrayRef, Float32Array, Int64Array, StructArray};
6use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
7use arrow::record_batch::RecordBatch;
8
/// A velocity command: six linear/angular velocity components plus a timestamp.
///
/// Units are not established by this module (the ROS `geometry_msgs/Twist`
/// convention would be m/s and rad/s — confirm against producers).
/// `Default` yields all-zero velocities and a zero timestamp.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct CmdVel {
    /// Linear velocity component along x.
    pub linear_x: f32,
    /// Linear velocity component along y.
    pub linear_y: f32,
    /// Linear velocity component along z.
    pub linear_z: f32,
    /// Angular velocity component about x.
    pub angular_x: f32,
    /// Angular velocity component about y.
    pub angular_y: f32,
    /// Angular velocity component about z.
    pub angular_z: f32,
    /// Timestamp in nanoseconds (per the field name); the clock/epoch is not
    /// established here.
    pub timestamp_ns: i64,
}
35
36pub fn schema() -> SchemaRef {
38 static SCHEMA: OnceLock<SchemaRef> = OnceLock::new();
39 SCHEMA
40 .get_or_init(|| {
41 Arc::new(Schema::new(vec![
42 Field::new("linear_x", DataType::Float32, false),
43 Field::new("linear_y", DataType::Float32, false),
44 Field::new("linear_z", DataType::Float32, false),
45 Field::new("angular_x", DataType::Float32, false),
46 Field::new("angular_y", DataType::Float32, false),
47 Field::new("angular_z", DataType::Float32, false),
48 Field::new("timestamp_ns", DataType::Int64, false),
49 ]))
50 })
51 .clone()
52}
53
54pub fn to_record_batch(twist: &CmdVel) -> Result<RecordBatch, arrow::error::ArrowError> {
66 let columns: Vec<ArrayRef> = vec![
67 Arc::new(Float32Array::from_iter_values(std::iter::once(
68 twist.linear_x,
69 ))),
70 Arc::new(Float32Array::from_iter_values(std::iter::once(
71 twist.linear_y,
72 ))),
73 Arc::new(Float32Array::from_iter_values(std::iter::once(
74 twist.linear_z,
75 ))),
76 Arc::new(Float32Array::from_iter_values(std::iter::once(
77 twist.angular_x,
78 ))),
79 Arc::new(Float32Array::from_iter_values(std::iter::once(
80 twist.angular_y,
81 ))),
82 Arc::new(Float32Array::from_iter_values(std::iter::once(
83 twist.angular_z,
84 ))),
85 Arc::new(Int64Array::from_iter_values(std::iter::once(
86 twist.timestamp_ns,
87 ))),
88 ];
89 RecordBatch::try_new(schema(), columns)
90}
91
92pub fn from_struct_array(array: &StructArray) -> Result<CmdVel, arrow::error::ArrowError> {
108 if array.is_empty() {
109 return Err(arrow::error::ArrowError::InvalidArgumentError(
110 "cmd_vel struct array is empty".into(),
111 ));
112 }
113 let schema = schema();
114 let names = schema.fields().iter().map(|f| f.name().clone());
115 let mut out = CmdVel::default();
116 for (idx, name) in names.enumerate() {
117 let col = array.column(idx);
118 match name.as_str() {
119 "linear_x" => out.linear_x = col_f32(col, "linear_x")?,
120 "linear_y" => out.linear_y = col_f32(col, "linear_y")?,
121 "linear_z" => out.linear_z = col_f32(col, "linear_z")?,
122 "angular_x" => out.angular_x = col_f32(col, "angular_x")?,
123 "angular_y" => out.angular_y = col_f32(col, "angular_y")?,
124 "angular_z" => out.angular_z = col_f32(col, "angular_z")?,
125 "timestamp_ns" => out.timestamp_ns = col_i64(col, "timestamp_ns")?,
126 other => {
127 return Err(arrow::error::ArrowError::SchemaError(format!(
128 "unexpected cmd_vel column '{other}'"
129 )));
130 }
131 }
132 }
133 Ok(out)
134}
135
136fn col_f32(col: &ArrayRef, name: &str) -> Result<f32, arrow::error::ArrowError> {
137 col.as_any()
138 .downcast_ref::<Float32Array>()
139 .ok_or_else(|| {
140 arrow::error::ArrowError::SchemaError(format!("cmd_vel '{name}' not Float32"))
141 })
142 .map(|a| a.value(0))
143}
144
145fn col_i64(col: &ArrayRef, name: &str) -> Result<i64, arrow::error::ArrowError> {
146 col.as_any()
147 .downcast_ref::<Int64Array>()
148 .ok_or_else(|| arrow::error::ArrowError::SchemaError(format!("cmd_vel '{name}' not Int64")))
149 .map(|a| a.value(0))
150}
151
#[cfg(test)]
mod tests {
    use super::*;

    /// A twist encoded with `to_record_batch` and decoded again with
    /// `from_struct_array` must come back unchanged.
    #[test]
    fn round_trips_through_struct_array() {
        let original = CmdVel {
            linear_x: 0.4,
            angular_z: 0.3,
            timestamp_ns: 999,
            ..CmdVel::default()
        };
        let struct_array = StructArray::from(to_record_batch(&original).expect("convert"));
        let decoded = from_struct_array(&struct_array).expect("decode");
        assert_eq!(decoded, original);
    }
}
172}