1use crate::FrameIdString;
5use crate::frames::{FrameId, FramePair};
6use crate::velocity::VelocityTransform;
7use bincode::{Decode, Encode};
8use cu_spatial_payloads::Transform3D;
9use cu29::clock::{CuTime, CuTimeRange, Tov};
10use cu29::cutask::CuStampedData;
11use cu29::prelude::CuMsgPayload;
12use num_traits;
13use serde::{Deserialize, Serialize};
14use std::fmt::Debug;
15
16pub type StampedFrameTransform<T> = CuStampedData<FrameTransform<T>, ()>;
18
19#[derive(Clone, Debug, Serialize, Deserialize, Default)]
42pub struct FrameTransform<T: Copy + Debug + Default + Serialize + 'static> {
43 pub transform: Transform3D<T>,
45 pub parent_frame: FrameIdString,
47 pub child_frame: FrameIdString,
49}
50
51impl<T: Copy + Debug + Default + Serialize + 'static> FrameTransform<T> {
52 pub fn new(
54 transform: Transform3D<T>,
55 parent_frame: impl AsRef<str>,
56 child_frame: impl AsRef<str>,
57 ) -> Self {
58 Self {
59 transform,
60 parent_frame: FrameIdString::from(parent_frame.as_ref())
61 .expect("Parent frame name too long (max 64 chars)"),
62 child_frame: FrameIdString::from(child_frame.as_ref())
63 .expect("Child frame name too long (max 64 chars)"),
64 }
65 }
66
67 pub fn from_stamped(stamped: &crate::transform::StampedTransform<T>) -> Self {
69 Self {
70 transform: stamped.transform,
71 parent_frame: FrameIdString::from(stamped.parent_frame.as_str())
72 .expect("Parent frame name too long"),
73 child_frame: FrameIdString::from(stamped.child_frame.as_str())
74 .expect("Child frame name too long"),
75 }
76 }
77}
78
79impl<T: Copy + Debug + Default + Serialize + 'static> Encode for FrameTransform<T>
81where
82 T: Encode,
83{
84 fn encode<E: bincode::enc::Encoder>(
85 &self,
86 encoder: &mut E,
87 ) -> Result<(), bincode::error::EncodeError> {
88 self.transform.encode(encoder)?;
89 self.parent_frame.encode(encoder)?;
90 self.child_frame.encode(encoder)?;
91 Ok(())
92 }
93}
94
95impl<T: Copy + Debug + Default + Serialize + 'static> Decode<()> for FrameTransform<T>
96where
97 T: Decode<()>,
98{
99 fn decode<D: bincode::de::Decoder<Context = ()>>(
100 decoder: &mut D,
101 ) -> Result<Self, bincode::error::DecodeError> {
102 let transform = Transform3D::decode(decoder)?;
103 let parent_frame_str = String::decode(decoder)?;
104 let child_frame_str = String::decode(decoder)?;
105 let parent_frame = FrameIdString::from(&parent_frame_str).map_err(|_| {
106 bincode::error::DecodeError::OtherString("Parent frame name too long".to_string())
107 })?;
108 let child_frame = FrameIdString::from(&child_frame_str).map_err(|_| {
109 bincode::error::DecodeError::OtherString("Child frame name too long".to_string())
110 })?;
111 Ok(Self {
112 transform,
113 parent_frame,
114 child_frame,
115 })
116 }
117}
118
119pub type TypedStampedFrameTransform<T> = CuStampedData<Transform3D<T>, ()>;
121
122#[derive(Debug, Clone)]
124pub struct TypedTransform<T, Parent, Child>
125where
126 T: CuMsgPayload + Copy + Debug + 'static,
127 Parent: FrameId,
128 Child: FrameId,
129{
130 pub transform: TypedStampedFrameTransform<T>,
132 pub frames: FramePair<Parent, Child>,
134}
135
136impl<T, Parent, Child> TypedTransform<T, Parent, Child>
137where
138 T: CuMsgPayload + Copy + Debug + 'static,
139 Parent: FrameId,
140 Child: FrameId,
141{
142 pub fn new(transform: Transform3D<T>, time: CuTime) -> Self {
144 let mut transform = TypedStampedFrameTransform::new(Some(transform));
145 transform.tov = Tov::Time(time);
146
147 let frames = FramePair::new();
148
149 Self { transform, frames }
150 }
151
152 pub fn transform(&self) -> Option<&Transform3D<T>> {
154 self.transform.payload()
155 }
156
157 pub fn timestamp(&self) -> Option<CuTime> {
159 match self.transform.tov {
160 Tov::Time(time) => Some(time),
161 _ => None,
162 }
163 }
164
165 pub fn parent_id(&self) -> u32 {
167 Parent::ID
168 }
169
170 pub fn child_id(&self) -> u32 {
172 Child::ID
173 }
174
175 pub fn parent_name(&self) -> &'static str {
177 Parent::NAME
178 }
179
180 pub fn child_name(&self) -> &'static str {
182 Child::NAME
183 }
184}
185
186#[derive(Debug)]
189pub struct TypedTransformBuffer<T, Parent, Child, const N: usize>
190where
191 T: CuMsgPayload + Copy + Debug + 'static,
192 Parent: FrameId,
193 Child: FrameId,
194{
195 transforms: [Option<TypedTransform<T, Parent, Child>>; N],
197 count: usize,
199}
200
201impl<T, Parent, Child, const N: usize> TypedTransformBuffer<T, Parent, Child, N>
202where
203 T: CuMsgPayload + Copy + Debug + 'static,
204 Parent: FrameId,
205 Child: FrameId,
206{
207 pub fn new() -> Self {
209 Self {
210 transforms: std::array::from_fn(|_| None),
211 count: 0,
212 }
213 }
214
215 pub fn add_transform(&mut self, transform_msg: TypedTransform<T, Parent, Child>) {
217 if self.count < N {
218 self.transforms[self.count] = Some(transform_msg);
220 self.count += 1;
221 } else {
222 for i in 0..N - 1 {
224 self.transforms[i] = self.transforms[i + 1].take();
225 }
226 self.transforms[N - 1] = Some(transform_msg);
227 }
228
229 self.sort_by_time();
231 }
232
233 fn sort_by_time(&mut self) {
235 let mut time_indices: Vec<(CuTime, usize)> = Vec::new();
237
238 for i in 0..self.count {
239 if let Some(ref transform) = self.transforms[i]
240 && let Some(time) = transform.timestamp()
241 {
242 time_indices.push((time, i));
243 }
244 }
245
246 time_indices.sort_by_key(|(time, _)| *time);
248
249 let mut new_transforms: [Option<TypedTransform<T, Parent, Child>>; N] =
251 std::array::from_fn(|_| None);
252
253 for (idx, (_, old_idx)) in time_indices.iter().enumerate() {
254 new_transforms[idx] = self.transforms[*old_idx].take();
255 }
256
257 self.transforms = new_transforms;
258 }
259
260 pub fn get_latest_transform(&self) -> Option<&TypedTransform<T, Parent, Child>> {
262 if self.count == 0 {
263 return None;
264 }
265
266 self.transforms[self.count - 1].as_ref()
268 }
269
270 pub fn get_closest_transform(&self, time: CuTime) -> Option<&TypedTransform<T, Parent, Child>> {
272 if self.count == 0 {
273 return None;
274 }
275
276 let mut closest_idx = 0;
277 let mut closest_diff = u64::MAX;
278
279 for i in 0..self.count {
280 if let Some(ref transform) = self.transforms[i]
281 && let Some(transform_time) = transform.timestamp()
282 {
283 let diff = if time.as_nanos() > transform_time.as_nanos() {
284 time.as_nanos() - transform_time.as_nanos()
285 } else {
286 transform_time.as_nanos() - time.as_nanos()
287 };
288
289 if diff < closest_diff {
290 closest_diff = diff;
291 closest_idx = i;
292 }
293 }
294 }
295
296 self.transforms[closest_idx].as_ref()
297 }
298
299 pub fn get_time_range(&self) -> Option<CuTimeRange> {
301 if self.count == 0 {
302 return None;
303 }
304
305 let start = self.transforms[0].as_ref()?.timestamp()?;
307 let end = self.transforms[self.count - 1].as_ref()?.timestamp()?;
308
309 Some(CuTimeRange { start, end })
310 }
311
312 #[allow(clippy::type_complexity)]
314 pub fn get_transforms_around(
315 &self,
316 time: CuTime,
317 ) -> Option<(
318 &TypedTransform<T, Parent, Child>,
319 &TypedTransform<T, Parent, Child>,
320 )> {
321 if self.count < 2 {
322 return None;
323 }
324
325 let mut before_idx = None;
327 let mut after_idx = None;
328
329 for i in 0..self.count {
330 if let Some(ref transform) = self.transforms[i]
331 && let Some(transform_time) = transform.timestamp()
332 {
333 if transform_time <= time {
334 before_idx = Some(i);
335 } else if after_idx.is_none() {
336 after_idx = Some(i);
337 break;
338 }
339 }
340 }
341
342 match (before_idx, after_idx) {
343 (Some(before), Some(after)) => Some((
344 self.transforms[before].as_ref()?,
345 self.transforms[after].as_ref()?,
346 )),
347 (Some(before), None) => {
348 if before > 0 {
350 Some((
351 self.transforms[before - 1].as_ref()?,
352 self.transforms[before].as_ref()?,
353 ))
354 } else {
355 None
356 }
357 }
358 (None, Some(after)) => {
359 if after + 1 < self.count {
361 Some((
362 self.transforms[after].as_ref()?,
363 self.transforms[after + 1].as_ref()?,
364 ))
365 } else {
366 None
367 }
368 }
369 _ => None,
370 }
371 }
372}
373
374impl<T, Parent, Child, const N: usize> Default for TypedTransformBuffer<T, Parent, Child, N>
375where
376 T: CuMsgPayload + Copy + Debug + 'static,
377 Parent: FrameId,
378 Child: FrameId,
379{
380 fn default() -> Self {
381 Self::new()
382 }
383}
384
385impl<T, Parent, Child> TypedTransform<T, Parent, Child>
387where
388 T: CuMsgPayload
389 + Copy
390 + Debug
391 + Default
392 + std::ops::Add<Output = T>
393 + std::ops::Sub<Output = T>
394 + std::ops::Mul<Output = T>
395 + std::ops::Div<Output = T>
396 + num_traits::NumCast
397 + 'static,
398 Parent: FrameId,
399 Child: FrameId,
400{
401 pub fn compute_velocity(&self, previous: &Self) -> Option<VelocityTransform<T>> {
403 let current_time = self.timestamp()?;
404 let previous_time = previous.timestamp()?;
405 let current_transform = self.transform()?;
406 let previous_transform = previous.transform()?;
407
408 let dt_nanos = current_time.as_nanos() as i64 - previous_time.as_nanos() as i64;
410 if dt_nanos <= 0 {
411 return None;
412 }
413
414 let dt = dt_nanos as f64 / 1_000_000_000.0;
416
417 let dt_t = num_traits::cast::cast::<f64, T>(dt)?;
418
419 let current_mat = current_transform.to_matrix();
421 let previous_mat = previous_transform.to_matrix();
422 let mut linear_velocity = [T::default(); 3];
423 for (i, vel) in linear_velocity.iter_mut().enumerate() {
424 let pos_diff = current_mat[3][i] - previous_mat[3][i];
425 *vel = pos_diff / dt_t;
426 }
427
428 let angular_velocity = [T::default(); 3];
430
431 Some(VelocityTransform {
432 linear: linear_velocity,
433 angular: angular_velocity,
434 })
435 }
436}
437
#[cfg(test)]
mod tests {
    use super::*;
    use crate::frames::{RobotFrame, WorldFrame};
    use cu29::clock::CuDuration;

    type WorldToRobotFrameTransform = TypedTransform<f32, WorldFrame, RobotFrame>;
    type WorldToRobotBuffer = TypedTransformBuffer<f32, WorldFrame, RobotFrame, 10>;

    /// Asserts that `actual` is within `epsilon` of `expected`.
    fn assert_approx_eq(actual: f32, expected: f32, epsilon: f32) {
        let diff = (actual - expected).abs();
        assert!(
            diff <= epsilon,
            "expected {expected}, got {actual}, difference {diff} exceeds epsilon {epsilon}",
        );
    }

    /// A default-transform message stamped at `nanos`.
    fn msg_at(nanos: u64) -> WorldToRobotFrameTransform {
        WorldToRobotFrameTransform::new(Transform3D::<f32>::default(), CuDuration(nanos))
    }

    /// Timestamp of `msg` in nanoseconds; panics if it has none.
    fn stamp_nanos(msg: &WorldToRobotFrameTransform) -> u64 {
        msg.timestamp().unwrap().as_nanos()
    }

    #[test]
    fn test_typed_transform_msg_creation() {
        let transform = Transform3D::<f32>::default();
        let time = CuDuration(1000);

        let msg = WorldToRobotFrameTransform::new(transform, time);

        assert_eq!(msg.parent_id(), WorldFrame::ID);
        assert_eq!(msg.child_id(), RobotFrame::ID);
        assert_eq!(msg.parent_name(), "world");
        assert_eq!(msg.child_name(), "robot");
        assert_eq!(msg.timestamp().unwrap().as_nanos(), 1000);
    }

    #[test]
    fn test_typed_transform_buffer() {
        let mut buffer = WorldToRobotBuffer::new();

        buffer.add_transform(msg_at(1000));
        buffer.add_transform(msg_at(2000));

        let latest = buffer.get_latest_transform().unwrap();
        assert_eq!(stamp_nanos(latest), 2000);

        let range = buffer.get_time_range().unwrap();
        assert_eq!(range.start.as_nanos(), 1000);
        assert_eq!(range.end.as_nanos(), 2000);
    }

    #[test]
    fn test_buffer_sorts_out_of_order_inserts() {
        let mut buffer = WorldToRobotBuffer::new();
        for t in [3000, 1000, 2000] {
            buffer.add_transform(msg_at(t));
        }

        assert_eq!(stamp_nanos(buffer.get_latest_transform().unwrap()), 3000);
        let range = buffer.get_time_range().unwrap();
        assert_eq!(range.start.as_nanos(), 1000);
        assert_eq!(range.end.as_nanos(), 3000);
    }

    #[test]
    fn test_buffer_evicts_oldest_when_full() {
        let mut buffer = WorldToRobotBuffer::new();
        for k in 1..=11u64 {
            buffer.add_transform(msg_at(k * 1000));
        }

        let range = buffer.get_time_range().unwrap();
        assert_eq!(range.start.as_nanos(), 2000);
        assert_eq!(range.end.as_nanos(), 11_000);
    }

    #[test]
    fn test_empty_buffer_queries() {
        let buffer = WorldToRobotBuffer::new();

        assert!(buffer.get_latest_transform().is_none());
        assert!(buffer.get_closest_transform(CuDuration(1000)).is_none());
        assert!(buffer.get_time_range().is_none());
        assert!(buffer.get_transforms_around(CuDuration(1000)).is_none());
    }

    #[test]
    fn test_closest_transform() {
        let mut buffer = WorldToRobotBuffer::new();

        buffer.add_transform(msg_at(1000));
        buffer.add_transform(msg_at(3000));

        let closest = buffer.get_closest_transform(CuDuration(1500));
        assert_eq!(stamp_nanos(closest.unwrap()), 1000);

        let closest = buffer.get_closest_transform(CuDuration(2500));
        assert_eq!(stamp_nanos(closest.unwrap()), 3000);
    }

    #[test]
    fn test_transforms_around() {
        let mut buffer = WorldToRobotBuffer::new();
        buffer.add_transform(msg_at(1000));
        // A single entry cannot bracket anything.
        assert!(buffer.get_transforms_around(CuDuration(1000)).is_none());

        buffer.add_transform(msg_at(2000));
        buffer.add_transform(msg_at(3000));

        let pair = |query: u64| {
            let (before, after) = buffer.get_transforms_around(CuDuration(query)).unwrap();
            (stamp_nanos(before), stamp_nanos(after))
        };
        assert_eq!(pair(1500), (1000, 2000));
        assert_eq!(pair(2000), (2000, 3000));
        assert_eq!(pair(500), (1000, 2000));
        assert_eq!(pair(3500), (2000, 3000));
    }

    #[test]
    fn test_velocity_computation() {
        use crate::test_utils::translation_transform;

        let transform1 = translation_transform(0.0f32, 0.0, 0.0);
        let transform2 = translation_transform(1.0f32, 2.0, 0.0);

        let msg1 = WorldToRobotFrameTransform::new(transform1, CuDuration(1_000_000_000));
        let msg2 = WorldToRobotFrameTransform::new(transform2, CuDuration(2_000_000_000));
        let velocity = msg2.compute_velocity(&msg1).unwrap();

        assert_approx_eq(velocity.linear[0], 1.0, 1e-5);
        assert_approx_eq(velocity.linear[1], 2.0, 1e-5);
        assert_approx_eq(velocity.linear[2], 0.0, 1e-5);
    }

    #[test]
    fn test_velocity_requires_increasing_time() {
        use crate::test_utils::translation_transform;

        let msg1 = WorldToRobotFrameTransform::new(
            translation_transform(0.0f32, 0.0, 0.0),
            CuDuration(1_000_000_000),
        );
        let msg2 = WorldToRobotFrameTransform::new(
            translation_transform(1.0f32, 2.0, 0.0),
            CuDuration(2_000_000_000),
        );

        assert!(msg1.compute_velocity(&msg2).is_none());
        assert!(msg2.compute_velocity(&msg2).is_none());
    }
}