Skip to main content

cu_transform/
transform_payload.rs

1//! Transform message system using CuMsg and compile-time frame types
2//! This replaces the StampedTransform approach with a more Copper-native design
3
4use crate::FrameIdString;
5use crate::frames::{FrameId, FramePair};
6use crate::velocity::VelocityTransform;
7use bincode::{Decode, Encode};
8use cu_spatial_payloads::Transform3D;
9#[allow(unused_imports)]
10use cu29::bevy_reflect;
11use cu29::clock::{CuTime, CuTimeRange, Tov};
12use cu29::cutask::CuStampedData;
13use cu29::prelude::{CuMsgPayload, Reflect};
14use num_traits;
15use serde::{Deserialize, Serialize};
16use std::fmt::Debug;
17
/// Timestamped relative transform: a [`FrameTransform`] carried as the payload of
/// [`CuStampedData`], with `()` as the second type parameter.
///
/// The parent/child frame names live inside the payload; the timestamp is
/// handled by `CuStampedData`.
pub type StampedFrameTransform<T> = CuStampedData<FrameTransform<T>, ()>;
20
/// Transform message usable as a payload for `CuStampedData`.
///
/// This contains just the transform data and the names of the two frames it
/// relates. It carries no timestamp of its own, as timestamps are handled by
/// `CuStampedData` (see [`StampedFrameTransform`]).
///
/// # Example
/// ```
/// use cu_transform::{FrameTransform, Transform3D};
/// use cu29::prelude::*;
/// use cu29::clock::{CuDuration, Tov};
/// use cu_transform::transform_payload::StampedFrameTransform;
///
/// // Create a transform message
/// let transform = Transform3D::<f32>::default();
/// let payload = FrameTransform::new(
///     transform,
///     "world",
///     "robot"
/// );
///
/// let data = StampedFrameTransform::new(Some(payload));
///
/// ```
#[derive(Clone, Debug, Serialize, Deserialize, Default, Reflect)]
#[reflect(opaque, from_reflect = false, no_field_bounds)]
pub struct FrameTransform<T: Copy + Debug + Default + Serialize + 'static> {
    /// The actual transform between the two frames.
    pub transform: Transform3D<T>,
    /// Parent frame identifier, stored as a `FrameIdString`.
    pub parent_frame: FrameIdString,
    /// Child frame identifier, stored as a `FrameIdString`.
    pub child_frame: FrameIdString,
}
53
54impl<T: Copy + Debug + Default + Serialize + 'static> FrameTransform<T> {
55    /// Create a new transform message
56    pub fn new(
57        transform: Transform3D<T>,
58        parent_frame: impl AsRef<str>,
59        child_frame: impl AsRef<str>,
60    ) -> Self {
61        Self {
62            transform,
63            parent_frame: FrameIdString::from(parent_frame.as_ref())
64                .expect("Parent frame name too long (max 64 chars)"),
65            child_frame: FrameIdString::from(child_frame.as_ref())
66                .expect("Child frame name too long (max 64 chars)"),
67        }
68    }
69
70    /// Create from a StampedTransform (for migration)
71    pub fn from_stamped(stamped: &crate::transform::StampedTransform<T>) -> Self {
72        Self {
73            transform: stamped.transform,
74            parent_frame: FrameIdString::from(stamped.parent_frame.as_str())
75                .expect("Parent frame name too long"),
76            child_frame: FrameIdString::from(stamped.child_frame.as_str())
77                .expect("Child frame name too long"),
78        }
79    }
80}
81
82// Manual Encode/Decode implementations to work with Transform3D's specific implementations
83impl<T: Copy + Debug + Default + Serialize + 'static> Encode for FrameTransform<T>
84where
85    T: Encode,
86{
87    fn encode<E: bincode::enc::Encoder>(
88        &self,
89        encoder: &mut E,
90    ) -> Result<(), bincode::error::EncodeError> {
91        self.transform.encode(encoder)?;
92        self.parent_frame.encode(encoder)?;
93        self.child_frame.encode(encoder)?;
94        Ok(())
95    }
96}
97
98impl<T: Copy + Debug + Default + Serialize + 'static> Decode<()> for FrameTransform<T>
99where
100    T: Decode<()>,
101{
102    fn decode<D: bincode::de::Decoder<Context = ()>>(
103        decoder: &mut D,
104    ) -> Result<Self, bincode::error::DecodeError> {
105        let transform = <Transform3D<T> as Decode<()>>::decode(decoder)?;
106        let parent_frame_str = String::decode(decoder)?;
107        let child_frame_str = String::decode(decoder)?;
108        let parent_frame = FrameIdString::from(&parent_frame_str).map_err(|_| {
109            bincode::error::DecodeError::OtherString("Parent frame name too long".to_string())
110        })?;
111        let child_frame = FrameIdString::from(&child_frame_str).map_err(|_| {
112            bincode::error::DecodeError::OtherString("Child frame name too long".to_string())
113        })?;
114        Ok(Self {
115            transform,
116            parent_frame,
117            child_frame,
118        })
119    }
120}
121
/// Timestamped transform without runtime frame names: a bare [`Transform3D`]
/// carried as the payload of [`CuStampedData`], with `()` as the second type
/// parameter.
///
/// Used by [`TypedTransform`], which records the parent/child frames in its
/// type parameters instead of in the payload.
pub type TypedStampedFrameTransform<T> = CuStampedData<Transform3D<T>, ()>;
124
/// A typed transform message that carries frame relationship information at compile time.
///
/// `Parent` and `Child` are [`FrameId`] types; their `ID` and `NAME` constants
/// are what the accessor methods on this type return, so no frame names are
/// stored per message.
#[derive(Debug, Clone)]
pub struct TypedTransform<T, Parent, Child>
where
    T: CuMsgPayload + Copy + Debug + 'static,
    Parent: FrameId,
    Child: FrameId,
{
    /// The actual transform message (payload plus its `tov` time of validity).
    pub transform: TypedStampedFrameTransform<T>,
    /// Frame relationship (zero-sized at runtime)
    pub frames: FramePair<Parent, Child>,
}
138
139impl<T, Parent, Child> TypedTransform<T, Parent, Child>
140where
141    T: CuMsgPayload + Copy + Debug + 'static,
142    Parent: FrameId,
143    Child: FrameId,
144{
145    /// Create a new typed transform message
146    pub fn new(transform: Transform3D<T>, time: CuTime) -> Self {
147        let mut transform = TypedStampedFrameTransform::new(Some(transform));
148        transform.tov = Tov::Time(time);
149
150        let frames = FramePair::new();
151
152        Self { transform, frames }
153    }
154
155    /// Get the transform data
156    pub fn transform(&self) -> Option<&Transform3D<T>> {
157        self.transform.payload()
158    }
159
160    /// Get the timestamp from the message
161    pub fn timestamp(&self) -> Option<CuTime> {
162        match self.transform.tov {
163            Tov::Time(time) => Some(time),
164            _ => None,
165        }
166    }
167
168    /// Get the parent frame ID
169    pub fn parent_id(&self) -> u32 {
170        Parent::ID
171    }
172
173    /// Get the child frame ID
174    pub fn child_id(&self) -> u32 {
175        Child::ID
176    }
177
178    /// Get the parent frame name
179    pub fn parent_name(&self) -> &'static str {
180        Parent::NAME
181    }
182
183    /// Get the child frame name
184    pub fn child_name(&self) -> &'static str {
185        Child::NAME
186    }
187}
188
/// Fixed-size transform buffer using compile-time frame types.
///
/// This replaces the dynamic Vec-based approach with a fixed-size array of
/// `N` slots. All transforms in one buffer share the same `Parent`/`Child`
/// frame pair, fixed by the type parameters.
#[derive(Debug)]
pub struct TypedTransformBuffer<T, Parent, Child, const N: usize>
where
    T: CuMsgPayload + Copy + Debug + 'static,
    Parent: FrameId,
    Child: FrameId,
{
    /// Fixed-size array of transform messages; an empty slot is `None`.
    transforms: [Option<TypedTransform<T, Parent, Child>>; N],
    /// Current number of transforms stored; lookups only scan slots `0..count`.
    count: usize,
}
203
204impl<T, Parent, Child, const N: usize> TypedTransformBuffer<T, Parent, Child, N>
205where
206    T: CuMsgPayload + Copy + Debug + 'static,
207    Parent: FrameId,
208    Child: FrameId,
209{
210    /// Create a new typed transform buffer
211    pub fn new() -> Self {
212        Self {
213            transforms: std::array::from_fn(|_| None),
214            count: 0,
215        }
216    }
217
218    /// Add a transform to the buffer
219    pub fn add_transform(&mut self, transform_msg: TypedTransform<T, Parent, Child>) {
220        if self.count < N {
221            // Still have space, just add to the end
222            self.transforms[self.count] = Some(transform_msg);
223            self.count += 1;
224        } else {
225            // Buffer is full, shift everything and add to the end
226            for i in 0..N - 1 {
227                self.transforms[i] = self.transforms[i + 1].take();
228            }
229            self.transforms[N - 1] = Some(transform_msg);
230        }
231
232        // Sort to maintain time ordering
233        self.sort_by_time();
234    }
235
236    fn transform_at(&self, index: usize) -> Option<&TypedTransform<T, Parent, Child>> {
237        self.transforms.get(index)?.as_ref()
238    }
239
240    fn timed_indices(&self) -> Vec<(usize, CuTime)> {
241        (0..self.count)
242            .filter_map(|index| {
243                let transform = self.transform_at(index)?;
244                Some((index, transform.timestamp()?))
245            })
246            .collect()
247    }
248
249    #[allow(clippy::type_complexity)]
250    fn transform_pair(
251        &self,
252        first: usize,
253        second: usize,
254    ) -> Option<(
255        &TypedTransform<T, Parent, Child>,
256        &TypedTransform<T, Parent, Child>,
257    )> {
258        Some((self.transform_at(first)?, self.transform_at(second)?))
259    }
260
261    /// Sort transforms by timestamp
262    fn sort_by_time(&mut self) {
263        let mut time_indices = self.timed_indices();
264
265        // Sort by timestamp
266        time_indices.sort_by_key(|(_, time)| *time);
267
268        // Create a new ordered array
269        let mut new_transforms: [Option<TypedTransform<T, Parent, Child>>; N] =
270            std::array::from_fn(|_| None);
271
272        for (new_idx, (old_idx, _)) in time_indices.into_iter().enumerate() {
273            new_transforms[new_idx] = self.transforms[old_idx].take();
274        }
275
276        self.transforms = new_transforms;
277    }
278
279    /// Get the latest transform
280    pub fn get_latest_transform(&self) -> Option<&TypedTransform<T, Parent, Child>> {
281        self.count
282            .checked_sub(1)
283            .and_then(|index| self.transform_at(index))
284    }
285
286    /// Get transform closest to specified time
287    pub fn get_closest_transform(&self, time: CuTime) -> Option<&TypedTransform<T, Parent, Child>> {
288        if self.count == 0 {
289            return None;
290        }
291
292        let closest_idx = self
293            .timed_indices()
294            .into_iter()
295            .min_by_key(|(_, transform_time)| time.as_nanos().abs_diff(transform_time.as_nanos()))
296            .map(|(index, _)| index)
297            .unwrap_or(0);
298
299        self.transform_at(closest_idx)
300    }
301
302    /// Get time range of stored transforms
303    pub fn get_time_range(&self) -> Option<CuTimeRange> {
304        if self.count == 0 {
305            return None;
306        }
307
308        // Since we maintain sorted order, first is min, last is max
309        let end_index = self.count.checked_sub(1)?;
310        let start = self.transform_at(0)?.timestamp()?;
311        let end = self.transform_at(end_index)?.timestamp()?;
312
313        Some(CuTimeRange { start, end })
314    }
315
316    /// Get two transforms around the specified time for velocity computation
317    #[allow(clippy::type_complexity)]
318    pub fn get_transforms_around(
319        &self,
320        time: CuTime,
321    ) -> Option<(
322        &TypedTransform<T, Parent, Child>,
323        &TypedTransform<T, Parent, Child>,
324    )> {
325        if self.count < 2 {
326            return None;
327        }
328
329        // Find transforms before and after the requested time
330        let mut before_idx = None;
331        let mut after_idx = None;
332
333        for i in 0..self.count {
334            let Some(transform) = self.transform_at(i) else {
335                continue;
336            };
337            let Some(transform_time) = transform.timestamp() else {
338                continue;
339            };
340
341            if transform_time <= time {
342                before_idx = Some(i);
343            } else if after_idx.is_none() {
344                after_idx = Some(i);
345                break;
346            }
347        }
348
349        match (before_idx, after_idx) {
350            (Some(before), Some(after)) => self.transform_pair(before, after),
351            (Some(before), None) if before > 0 => self.transform_pair(before - 1, before),
352            (None, Some(after)) if after + 1 < self.count => self.transform_pair(after, after + 1),
353            _ => None,
354        }
355    }
356}
357
/// The default buffer is an empty one, identical to [`TypedTransformBuffer::new`].
impl<T, Parent, Child, const N: usize> Default for TypedTransformBuffer<T, Parent, Child, N>
where
    T: CuMsgPayload + Copy + Debug + 'static,
    Parent: FrameId,
    Child: FrameId,
{
    /// Delegates to `new()` so both construction paths stay in sync.
    fn default() -> Self {
        Self::new()
    }
}
368
369/// Velocity computation for typed transforms
370impl<T, Parent, Child> TypedTransform<T, Parent, Child>
371where
372    T: CuMsgPayload
373        + Copy
374        + Debug
375        + Default
376        + std::ops::Add<Output = T>
377        + std::ops::Sub<Output = T>
378        + std::ops::Mul<Output = T>
379        + std::ops::Div<Output = T>
380        + num_traits::NumCast
381        + 'static,
382    Parent: FrameId,
383    Child: FrameId,
384{
385    /// Compute velocity from this transform and a previous transform
386    pub fn compute_velocity(&self, previous: &Self) -> Option<VelocityTransform<T>> {
387        let current_time = self.timestamp()?;
388        let previous_time = previous.timestamp()?;
389        let current_transform = self.transform()?;
390        let previous_transform = previous.transform()?;
391
392        // Compute time difference in nanoseconds, then convert to seconds
393        let dt_nanos = current_time.as_nanos() as i64 - previous_time.as_nanos() as i64;
394        if dt_nanos <= 0 {
395            return None;
396        }
397
398        // Convert nanoseconds to seconds (1e9 nanoseconds = 1 second)
399        let dt = dt_nanos as f64 / 1_000_000_000.0;
400
401        let dt_t = num_traits::cast::cast::<f64, T>(dt)?;
402
403        // Extract positions from transforms (column-major format)
404        let current_mat = current_transform.to_matrix();
405        let previous_mat = previous_transform.to_matrix();
406        let mut linear_velocity = [T::default(); 3];
407        for (i, vel) in linear_velocity.iter_mut().enumerate() {
408            let pos_diff = current_mat[3][i] - previous_mat[3][i];
409            *vel = pos_diff / dt_t;
410        }
411
412        // Compute angular velocity (simplified version for now)
413        let angular_velocity = [T::default(); 3];
414
415        Some(VelocityTransform {
416            linear: linear_velocity,
417            angular: angular_velocity,
418        })
419    }
420}
421
#[cfg(test)]
mod tests {
    use super::*;
    use crate::frames::{RobotFrame, WorldFrame};
    use cu29::clock::CuDuration;

    type WorldToRobotFrameTransform = TypedTransform<f32, WorldFrame, RobotFrame>;
    type WorldToRobotBuffer = TypedTransformBuffer<f32, WorldFrame, RobotFrame, 10>;

    /// Absolute-tolerance float comparison (stands in for `assert_relative_eq`).
    fn assert_approx_eq(actual: f32, expected: f32, epsilon: f32) {
        let diff = (actual - expected).abs();
        assert!(
            diff <= epsilon,
            "expected {expected}, got {actual}, difference {diff} exceeds epsilon {epsilon}",
        );
    }

    /// A default transform stamped at `time`.
    fn default_msg_at(time: CuDuration) -> WorldToRobotFrameTransform {
        WorldToRobotFrameTransform::new(Transform3D::<f32>::default(), time)
    }

    /// A buffer pre-filled with default transforms at the given times, in order.
    fn buffer_with(times: [CuDuration; 2]) -> WorldToRobotBuffer {
        let mut buffer = WorldToRobotBuffer::new();
        for time in times {
            buffer.add_transform(default_msg_at(time));
        }
        buffer
    }

    #[test]
    fn test_typed_transform_msg_creation() {
        let msg = default_msg_at(CuDuration(1000));

        // Frame metadata comes from the type parameters.
        assert_eq!(msg.parent_id(), WorldFrame::ID);
        assert_eq!(msg.child_id(), RobotFrame::ID);
        assert_eq!(msg.parent_name(), "world");
        assert_eq!(msg.child_name(), "robot");
        assert_eq!(msg.timestamp().unwrap().as_nanos(), 1000);
    }

    #[test]
    fn test_typed_transform_buffer() {
        let buffer = buffer_with([CuDuration(1000), CuDuration(2000)]);

        let latest = buffer.get_latest_transform().unwrap();
        assert_eq!(latest.timestamp().unwrap().as_nanos(), 2000);

        let range = buffer.get_time_range().unwrap();
        assert_eq!(range.start.as_nanos(), 1000);
        assert_eq!(range.end.as_nanos(), 2000);
    }

    #[test]
    fn test_closest_transform() {
        let buffer = buffer_with([CuDuration(1000), CuDuration(3000)]);

        // 1500 is nearer to 1000; 2500 is nearer to 3000.
        let near_first = buffer.get_closest_transform(CuDuration(1500));
        assert_eq!(near_first.unwrap().timestamp().unwrap().as_nanos(), 1000);

        let near_second = buffer.get_closest_transform(CuDuration(2500));
        assert_eq!(near_second.unwrap().timestamp().unwrap().as_nanos(), 3000);
    }

    #[test]
    fn test_velocity_computation() {
        use crate::test_utils::translation_transform;

        // Move by (1, 2, 0) over exactly one second.
        let earlier = WorldToRobotFrameTransform::new(
            translation_transform(0.0f32, 0.0, 0.0),
            CuDuration(1_000_000_000),
        );
        let later = WorldToRobotFrameTransform::new(
            translation_transform(1.0f32, 2.0, 0.0),
            CuDuration(2_000_000_000),
        );

        let velocity = later.compute_velocity(&earlier).unwrap();

        for (axis, expected) in [1.0, 2.0, 0.0].into_iter().enumerate() {
            assert_approx_eq(velocity.linear[axis], expected, 1e-5);
        }
    }
}