1use crate::FrameIdString;
5use crate::frames::{FrameId, FramePair};
6use crate::velocity::VelocityTransform;
7use bincode::{Decode, Encode};
8use cu_spatial_payloads::Transform3D;
9#[allow(unused_imports)]
10use cu29::bevy_reflect;
11use cu29::clock::{CuTime, CuTimeRange, Tov};
12use cu29::cutask::CuStampedData;
13use cu29::prelude::{CuMsgPayload, Reflect};
14use num_traits;
15use serde::{Deserialize, Serialize};
16use std::fmt::Debug;
17
18pub type StampedFrameTransform<T> = CuStampedData<FrameTransform<T>, ()>;
20
21#[derive(Clone, Debug, Serialize, Deserialize, Default, Reflect)]
44#[reflect(opaque, from_reflect = false, no_field_bounds)]
45pub struct FrameTransform<T: Copy + Debug + Default + Serialize + 'static> {
46 pub transform: Transform3D<T>,
48 pub parent_frame: FrameIdString,
50 pub child_frame: FrameIdString,
52}
53
54impl<T: Copy + Debug + Default + Serialize + 'static> FrameTransform<T> {
55 pub fn new(
57 transform: Transform3D<T>,
58 parent_frame: impl AsRef<str>,
59 child_frame: impl AsRef<str>,
60 ) -> Self {
61 Self {
62 transform,
63 parent_frame: FrameIdString::from(parent_frame.as_ref())
64 .expect("Parent frame name too long (max 64 chars)"),
65 child_frame: FrameIdString::from(child_frame.as_ref())
66 .expect("Child frame name too long (max 64 chars)"),
67 }
68 }
69
70 pub fn from_stamped(stamped: &crate::transform::StampedTransform<T>) -> Self {
72 Self {
73 transform: stamped.transform,
74 parent_frame: FrameIdString::from(stamped.parent_frame.as_str())
75 .expect("Parent frame name too long"),
76 child_frame: FrameIdString::from(stamped.child_frame.as_str())
77 .expect("Child frame name too long"),
78 }
79 }
80}
81
82impl<T: Copy + Debug + Default + Serialize + 'static> Encode for FrameTransform<T>
84where
85 T: Encode,
86{
87 fn encode<E: bincode::enc::Encoder>(
88 &self,
89 encoder: &mut E,
90 ) -> Result<(), bincode::error::EncodeError> {
91 self.transform.encode(encoder)?;
92 self.parent_frame.encode(encoder)?;
93 self.child_frame.encode(encoder)?;
94 Ok(())
95 }
96}
97
98impl<T: Copy + Debug + Default + Serialize + 'static> Decode<()> for FrameTransform<T>
99where
100 T: Decode<()>,
101{
102 fn decode<D: bincode::de::Decoder<Context = ()>>(
103 decoder: &mut D,
104 ) -> Result<Self, bincode::error::DecodeError> {
105 let transform = <Transform3D<T> as Decode<()>>::decode(decoder)?;
106 let parent_frame_str = String::decode(decoder)?;
107 let child_frame_str = String::decode(decoder)?;
108 let parent_frame = FrameIdString::from(&parent_frame_str).map_err(|_| {
109 bincode::error::DecodeError::OtherString("Parent frame name too long".to_string())
110 })?;
111 let child_frame = FrameIdString::from(&child_frame_str).map_err(|_| {
112 bincode::error::DecodeError::OtherString("Child frame name too long".to_string())
113 })?;
114 Ok(Self {
115 transform,
116 parent_frame,
117 child_frame,
118 })
119 }
120}
121
122pub type TypedStampedFrameTransform<T> = CuStampedData<Transform3D<T>, ()>;
124
125#[derive(Debug, Clone)]
127pub struct TypedTransform<T, Parent, Child>
128where
129 T: CuMsgPayload + Copy + Debug + 'static,
130 Parent: FrameId,
131 Child: FrameId,
132{
133 pub transform: TypedStampedFrameTransform<T>,
135 pub frames: FramePair<Parent, Child>,
137}
138
139impl<T, Parent, Child> TypedTransform<T, Parent, Child>
140where
141 T: CuMsgPayload + Copy + Debug + 'static,
142 Parent: FrameId,
143 Child: FrameId,
144{
145 pub fn new(transform: Transform3D<T>, time: CuTime) -> Self {
147 let mut transform = TypedStampedFrameTransform::new(Some(transform));
148 transform.tov = Tov::Time(time);
149
150 let frames = FramePair::new();
151
152 Self { transform, frames }
153 }
154
155 pub fn transform(&self) -> Option<&Transform3D<T>> {
157 self.transform.payload()
158 }
159
160 pub fn timestamp(&self) -> Option<CuTime> {
162 match self.transform.tov {
163 Tov::Time(time) => Some(time),
164 _ => None,
165 }
166 }
167
168 pub fn parent_id(&self) -> u32 {
170 Parent::ID
171 }
172
173 pub fn child_id(&self) -> u32 {
175 Child::ID
176 }
177
178 pub fn parent_name(&self) -> &'static str {
180 Parent::NAME
181 }
182
183 pub fn child_name(&self) -> &'static str {
185 Child::NAME
186 }
187}
188
189#[derive(Debug)]
192pub struct TypedTransformBuffer<T, Parent, Child, const N: usize>
193where
194 T: CuMsgPayload + Copy + Debug + 'static,
195 Parent: FrameId,
196 Child: FrameId,
197{
198 transforms: [Option<TypedTransform<T, Parent, Child>>; N],
200 count: usize,
202}
203
204impl<T, Parent, Child, const N: usize> TypedTransformBuffer<T, Parent, Child, N>
205where
206 T: CuMsgPayload + Copy + Debug + 'static,
207 Parent: FrameId,
208 Child: FrameId,
209{
210 pub fn new() -> Self {
212 Self {
213 transforms: std::array::from_fn(|_| None),
214 count: 0,
215 }
216 }
217
218 pub fn add_transform(&mut self, transform_msg: TypedTransform<T, Parent, Child>) {
220 if self.count < N {
221 self.transforms[self.count] = Some(transform_msg);
223 self.count += 1;
224 } else {
225 for i in 0..N - 1 {
227 self.transforms[i] = self.transforms[i + 1].take();
228 }
229 self.transforms[N - 1] = Some(transform_msg);
230 }
231
232 self.sort_by_time();
234 }
235
236 fn transform_at(&self, index: usize) -> Option<&TypedTransform<T, Parent, Child>> {
237 self.transforms.get(index)?.as_ref()
238 }
239
240 fn timed_indices(&self) -> Vec<(usize, CuTime)> {
241 (0..self.count)
242 .filter_map(|index| {
243 let transform = self.transform_at(index)?;
244 Some((index, transform.timestamp()?))
245 })
246 .collect()
247 }
248
249 #[allow(clippy::type_complexity)]
250 fn transform_pair(
251 &self,
252 first: usize,
253 second: usize,
254 ) -> Option<(
255 &TypedTransform<T, Parent, Child>,
256 &TypedTransform<T, Parent, Child>,
257 )> {
258 Some((self.transform_at(first)?, self.transform_at(second)?))
259 }
260
261 fn sort_by_time(&mut self) {
263 let mut time_indices = self.timed_indices();
264
265 time_indices.sort_by_key(|(_, time)| *time);
267
268 let mut new_transforms: [Option<TypedTransform<T, Parent, Child>>; N] =
270 std::array::from_fn(|_| None);
271
272 for (new_idx, (old_idx, _)) in time_indices.into_iter().enumerate() {
273 new_transforms[new_idx] = self.transforms[old_idx].take();
274 }
275
276 self.transforms = new_transforms;
277 }
278
279 pub fn get_latest_transform(&self) -> Option<&TypedTransform<T, Parent, Child>> {
281 self.count
282 .checked_sub(1)
283 .and_then(|index| self.transform_at(index))
284 }
285
286 pub fn get_closest_transform(&self, time: CuTime) -> Option<&TypedTransform<T, Parent, Child>> {
288 if self.count == 0 {
289 return None;
290 }
291
292 let closest_idx = self
293 .timed_indices()
294 .into_iter()
295 .min_by_key(|(_, transform_time)| time.as_nanos().abs_diff(transform_time.as_nanos()))
296 .map(|(index, _)| index)
297 .unwrap_or(0);
298
299 self.transform_at(closest_idx)
300 }
301
302 pub fn get_time_range(&self) -> Option<CuTimeRange> {
304 if self.count == 0 {
305 return None;
306 }
307
308 let end_index = self.count.checked_sub(1)?;
310 let start = self.transform_at(0)?.timestamp()?;
311 let end = self.transform_at(end_index)?.timestamp()?;
312
313 Some(CuTimeRange { start, end })
314 }
315
316 #[allow(clippy::type_complexity)]
318 pub fn get_transforms_around(
319 &self,
320 time: CuTime,
321 ) -> Option<(
322 &TypedTransform<T, Parent, Child>,
323 &TypedTransform<T, Parent, Child>,
324 )> {
325 if self.count < 2 {
326 return None;
327 }
328
329 let mut before_idx = None;
331 let mut after_idx = None;
332
333 for i in 0..self.count {
334 let Some(transform) = self.transform_at(i) else {
335 continue;
336 };
337 let Some(transform_time) = transform.timestamp() else {
338 continue;
339 };
340
341 if transform_time <= time {
342 before_idx = Some(i);
343 } else if after_idx.is_none() {
344 after_idx = Some(i);
345 break;
346 }
347 }
348
349 match (before_idx, after_idx) {
350 (Some(before), Some(after)) => self.transform_pair(before, after),
351 (Some(before), None) if before > 0 => self.transform_pair(before - 1, before),
352 (None, Some(after)) if after + 1 < self.count => self.transform_pair(after, after + 1),
353 _ => None,
354 }
355 }
356}
357
358impl<T, Parent, Child, const N: usize> Default for TypedTransformBuffer<T, Parent, Child, N>
359where
360 T: CuMsgPayload + Copy + Debug + 'static,
361 Parent: FrameId,
362 Child: FrameId,
363{
364 fn default() -> Self {
365 Self::new()
366 }
367}
368
369impl<T, Parent, Child> TypedTransform<T, Parent, Child>
371where
372 T: CuMsgPayload
373 + Copy
374 + Debug
375 + Default
376 + std::ops::Add<Output = T>
377 + std::ops::Sub<Output = T>
378 + std::ops::Mul<Output = T>
379 + std::ops::Div<Output = T>
380 + num_traits::NumCast
381 + 'static,
382 Parent: FrameId,
383 Child: FrameId,
384{
385 pub fn compute_velocity(&self, previous: &Self) -> Option<VelocityTransform<T>> {
387 let current_time = self.timestamp()?;
388 let previous_time = previous.timestamp()?;
389 let current_transform = self.transform()?;
390 let previous_transform = previous.transform()?;
391
392 let dt_nanos = current_time.as_nanos() as i64 - previous_time.as_nanos() as i64;
394 if dt_nanos <= 0 {
395 return None;
396 }
397
398 let dt = dt_nanos as f64 / 1_000_000_000.0;
400
401 let dt_t = num_traits::cast::cast::<f64, T>(dt)?;
402
403 let current_mat = current_transform.to_matrix();
405 let previous_mat = previous_transform.to_matrix();
406 let mut linear_velocity = [T::default(); 3];
407 for (i, vel) in linear_velocity.iter_mut().enumerate() {
408 let pos_diff = current_mat[3][i] - previous_mat[3][i];
409 *vel = pos_diff / dt_t;
410 }
411
412 let angular_velocity = [T::default(); 3];
414
415 Some(VelocityTransform {
416 linear: linear_velocity,
417 angular: angular_velocity,
418 })
419 }
420}
421
#[cfg(test)]
mod tests {
    use super::*;
    use crate::frames::{RobotFrame, WorldFrame};
    use cu29::clock::CuDuration;

    type WorldToRobotFrameTransform = TypedTransform<f32, WorldFrame, RobotFrame>;
    type WorldToRobotBuffer = TypedTransformBuffer<f32, WorldFrame, RobotFrame, 10>;

    /// Asserts that `actual` is within `epsilon` of `expected`.
    fn assert_approx_eq(actual: f32, expected: f32, epsilon: f32) {
        let diff = (actual - expected).abs();
        assert!(
            diff <= epsilon,
            "expected {expected}, got {actual}, difference {diff} exceeds epsilon {epsilon}",
        );
    }

    /// Identity transform stamped at `nanos`.
    fn stamped_identity(nanos: u64) -> WorldToRobotFrameTransform {
        WorldToRobotFrameTransform::new(Transform3D::<f32>::default(), CuDuration(nanos))
    }

    /// Timestamp of `msg` in nanoseconds; panics if it has none.
    fn nanos_of(msg: &WorldToRobotFrameTransform) -> u64 {
        msg.timestamp().unwrap().as_nanos()
    }

    #[test]
    fn test_typed_transform_msg_creation() {
        let transform = Transform3D::<f32>::default();
        let time = CuDuration(1000);

        let msg = WorldToRobotFrameTransform::new(transform, time);

        assert_eq!(msg.parent_id(), WorldFrame::ID);
        assert_eq!(msg.child_id(), RobotFrame::ID);
        assert_eq!(msg.parent_name(), "world");
        assert_eq!(msg.child_name(), "robot");
        assert_eq!(msg.timestamp().unwrap().as_nanos(), 1000);
    }

    #[test]
    fn test_typed_transform_buffer() {
        let mut buffer = WorldToRobotBuffer::new();

        buffer.add_transform(stamped_identity(1000));
        buffer.add_transform(stamped_identity(2000));

        let latest = buffer.get_latest_transform().unwrap();
        assert_eq!(nanos_of(latest), 2000);

        let range = buffer.get_time_range().unwrap();
        assert_eq!(range.start.as_nanos(), 1000);
        assert_eq!(range.end.as_nanos(), 2000);
    }

    #[test]
    fn test_out_of_order_insertion_is_sorted() {
        let mut buffer = WorldToRobotBuffer::new();
        for nanos in [3000, 1000, 2000] {
            buffer.add_transform(stamped_identity(nanos));
        }

        assert_eq!(nanos_of(buffer.get_latest_transform().unwrap()), 3000);
        let range = buffer.get_time_range().unwrap();
        assert_eq!(range.start.as_nanos(), 1000);
        assert_eq!(range.end.as_nanos(), 3000);
    }

    #[test]
    fn test_full_buffer_evicts_oldest() {
        let mut buffer = TypedTransformBuffer::<f32, WorldFrame, RobotFrame, 2>::new();
        for nanos in [1000, 2000, 3000] {
            buffer.add_transform(stamped_identity(nanos));
        }

        let range = buffer.get_time_range().unwrap();
        assert_eq!(range.start.as_nanos(), 2000);
        assert_eq!(range.end.as_nanos(), 3000);
    }

    #[test]
    fn test_closest_transform() {
        let mut buffer = WorldToRobotBuffer::new();

        buffer.add_transform(stamped_identity(1000));
        buffer.add_transform(stamped_identity(3000));

        let closest = buffer.get_closest_transform(CuDuration(1500));
        assert_eq!(nanos_of(closest.unwrap()), 1000);

        let closest = buffer.get_closest_transform(CuDuration(2500));
        assert_eq!(nanos_of(closest.unwrap()), 3000);
    }

    #[test]
    fn test_transforms_around() {
        let mut buffer = WorldToRobotBuffer::new();
        buffer.add_transform(stamped_identity(1000));
        // A single entry cannot bracket anything.
        assert!(buffer.get_transforms_around(CuDuration(1000)).is_none());

        buffer.add_transform(stamped_identity(2000));
        buffer.add_transform(stamped_identity(3000));

        let around = |nanos| {
            let (before, after) = buffer.get_transforms_around(CuDuration(nanos)).unwrap();
            (nanos_of(before), nanos_of(after))
        };
        assert_eq!(around(1500), (1000, 2000));
        assert_eq!(around(2000), (2000, 3000));
        assert_eq!(around(500), (1000, 2000));
        assert_eq!(around(3500), (2000, 3000));
    }

    #[test]
    fn test_velocity_computation() {
        use crate::test_utils::translation_transform;

        let transform1 = translation_transform(0.0f32, 0.0, 0.0);
        let transform2 = translation_transform(1.0f32, 2.0, 0.0);

        let msg1 = WorldToRobotFrameTransform::new(transform1, CuDuration(1_000_000_000));
        let msg2 = WorldToRobotFrameTransform::new(transform2, CuDuration(2_000_000_000));
        let velocity = msg2.compute_velocity(&msg1).unwrap();

        assert_approx_eq(velocity.linear[0], 1.0, 1e-5);
        assert_approx_eq(velocity.linear[1], 2.0, 1e-5);
        assert_approx_eq(velocity.linear[2], 0.0, 1e-5);
    }

    #[test]
    fn test_velocity_requires_increasing_time() {
        let earlier = stamped_identity(1_000_000_000);
        let later = stamped_identity(2_000_000_000);

        assert!(earlier.compute_velocity(&later).is_none());
        assert!(later.compute_velocity(&later).is_none());
    }
}
510}