Skip to main content

cu_transform/
transform_payload.rs

//! Transform message system built on Copper stamped messages (`CuStampedData`) and
//! compile-time frame types.
//!
//! This replaces the `StampedTransform` approach with a more Copper-native design.
3
4use crate::FrameIdString;
5use crate::frames::{FrameId, FramePair};
6use crate::velocity::VelocityTransform;
7use bincode::{Decode, Encode};
8use cu_spatial_payloads::Transform3D;
9use cu29::clock::{CuTime, CuTimeRange, Tov};
10use cu29::cutask::CuStampedData;
11use cu29::prelude::CuMsgPayload;
12use num_traits;
13use serde::{Deserialize, Serialize};
14use std::fmt::Debug;
15
16/// Transforms are timestamped Relative transforms.
17pub type StampedFrameTransform<T> = CuStampedData<FrameTransform<T>, ()>;
18
19/// Transform message useable as a payload for CuStampedData.
20/// This contains just the transform data without timestamps,
21/// as timestamps are handled by CuStampedData
22///
23/// # Example
24/// ```
25/// use cu_transform::{FrameTransform, Transform3D};
26/// use cu29::prelude::*;
27/// use cu29::clock::{CuDuration, Tov};
28/// use cu_transform::transform_payload::StampedFrameTransform;
29///
30/// // Create a transform message
31/// let transform = Transform3D::<f32>::default();
32/// let payload = FrameTransform::new(
33///     transform,
34///     "world",
35///     "robot"
36/// );
37///
38/// let data = StampedFrameTransform::new(Some(payload));
39///
40/// ```
41#[derive(Clone, Debug, Serialize, Deserialize, Default)]
42pub struct FrameTransform<T: Copy + Debug + Default + Serialize + 'static> {
43    /// The actual transform
44    pub transform: Transform3D<T>,
45    /// Parent frame identifier
46    pub parent_frame: FrameIdString,
47    /// Child frame identifier
48    pub child_frame: FrameIdString,
49}
50
51impl<T: Copy + Debug + Default + Serialize + 'static> FrameTransform<T> {
52    /// Create a new transform message
53    pub fn new(
54        transform: Transform3D<T>,
55        parent_frame: impl AsRef<str>,
56        child_frame: impl AsRef<str>,
57    ) -> Self {
58        Self {
59            transform,
60            parent_frame: FrameIdString::from(parent_frame.as_ref())
61                .expect("Parent frame name too long (max 64 chars)"),
62            child_frame: FrameIdString::from(child_frame.as_ref())
63                .expect("Child frame name too long (max 64 chars)"),
64        }
65    }
66
67    /// Create from a StampedTransform (for migration)
68    pub fn from_stamped(stamped: &crate::transform::StampedTransform<T>) -> Self {
69        Self {
70            transform: stamped.transform,
71            parent_frame: FrameIdString::from(stamped.parent_frame.as_str())
72                .expect("Parent frame name too long"),
73            child_frame: FrameIdString::from(stamped.child_frame.as_str())
74                .expect("Child frame name too long"),
75        }
76    }
77}
78
79// Manual Encode/Decode implementations to work with Transform3D's specific implementations
80impl<T: Copy + Debug + Default + Serialize + 'static> Encode for FrameTransform<T>
81where
82    T: Encode,
83{
84    fn encode<E: bincode::enc::Encoder>(
85        &self,
86        encoder: &mut E,
87    ) -> Result<(), bincode::error::EncodeError> {
88        self.transform.encode(encoder)?;
89        self.parent_frame.encode(encoder)?;
90        self.child_frame.encode(encoder)?;
91        Ok(())
92    }
93}
94
95impl<T: Copy + Debug + Default + Serialize + 'static> Decode<()> for FrameTransform<T>
96where
97    T: Decode<()>,
98{
99    fn decode<D: bincode::de::Decoder<Context = ()>>(
100        decoder: &mut D,
101    ) -> Result<Self, bincode::error::DecodeError> {
102        let transform = Transform3D::decode(decoder)?;
103        let parent_frame_str = String::decode(decoder)?;
104        let child_frame_str = String::decode(decoder)?;
105        let parent_frame = FrameIdString::from(&parent_frame_str).map_err(|_| {
106            bincode::error::DecodeError::OtherString("Parent frame name too long".to_string())
107        })?;
108        let child_frame = FrameIdString::from(&child_frame_str).map_err(|_| {
109            bincode::error::DecodeError::OtherString("Child frame name too long".to_string())
110        })?;
111        Ok(Self {
112            transform,
113            parent_frame,
114            child_frame,
115        })
116    }
117}
118
119/// Transforms are timestamped Relative transforms.
120pub type TypedStampedFrameTransform<T> = CuStampedData<Transform3D<T>, ()>;
121
122/// A typed transform message that carries frame relationship information at compile time
123#[derive(Debug, Clone)]
124pub struct TypedTransform<T, Parent, Child>
125where
126    T: CuMsgPayload + Copy + Debug + 'static,
127    Parent: FrameId,
128    Child: FrameId,
129{
130    /// The actual transform message
131    pub transform: TypedStampedFrameTransform<T>,
132    /// Frame relationship (zero-sized at runtime)
133    pub frames: FramePair<Parent, Child>,
134}
135
136impl<T, Parent, Child> TypedTransform<T, Parent, Child>
137where
138    T: CuMsgPayload + Copy + Debug + 'static,
139    Parent: FrameId,
140    Child: FrameId,
141{
142    /// Create a new typed transform message
143    pub fn new(transform: Transform3D<T>, time: CuTime) -> Self {
144        let mut transform = TypedStampedFrameTransform::new(Some(transform));
145        transform.tov = Tov::Time(time);
146
147        let frames = FramePair::new();
148
149        Self { transform, frames }
150    }
151
152    /// Get the transform data
153    pub fn transform(&self) -> Option<&Transform3D<T>> {
154        self.transform.payload()
155    }
156
157    /// Get the timestamp from the message
158    pub fn timestamp(&self) -> Option<CuTime> {
159        match self.transform.tov {
160            Tov::Time(time) => Some(time),
161            _ => None,
162        }
163    }
164
165    /// Get the parent frame ID
166    pub fn parent_id(&self) -> u32 {
167        Parent::ID
168    }
169
170    /// Get the child frame ID
171    pub fn child_id(&self) -> u32 {
172        Child::ID
173    }
174
175    /// Get the parent frame name
176    pub fn parent_name(&self) -> &'static str {
177        Parent::NAME
178    }
179
180    /// Get the child frame name
181    pub fn child_name(&self) -> &'static str {
182        Child::NAME
183    }
184}
185
186/// Fixed-size transform buffer using compile-time frame types
187/// This replaces the dynamic Vec-based approach with a fixed-size array
188#[derive(Debug)]
189pub struct TypedTransformBuffer<T, Parent, Child, const N: usize>
190where
191    T: CuMsgPayload + Copy + Debug + 'static,
192    Parent: FrameId,
193    Child: FrameId,
194{
195    /// Fixed-size array of transform messages
196    transforms: [Option<TypedTransform<T, Parent, Child>>; N],
197    /// Current number of transforms stored
198    count: usize,
199}
200
201impl<T, Parent, Child, const N: usize> TypedTransformBuffer<T, Parent, Child, N>
202where
203    T: CuMsgPayload + Copy + Debug + 'static,
204    Parent: FrameId,
205    Child: FrameId,
206{
207    /// Create a new typed transform buffer
208    pub fn new() -> Self {
209        Self {
210            transforms: std::array::from_fn(|_| None),
211            count: 0,
212        }
213    }
214
215    /// Add a transform to the buffer
216    pub fn add_transform(&mut self, transform_msg: TypedTransform<T, Parent, Child>) {
217        if self.count < N {
218            // Still have space, just add to the end
219            self.transforms[self.count] = Some(transform_msg);
220            self.count += 1;
221        } else {
222            // Buffer is full, shift everything and add to the end
223            for i in 0..N - 1 {
224                self.transforms[i] = self.transforms[i + 1].take();
225            }
226            self.transforms[N - 1] = Some(transform_msg);
227        }
228
229        // Sort to maintain time ordering
230        self.sort_by_time();
231    }
232
233    fn transform_at(&self, index: usize) -> Option<&TypedTransform<T, Parent, Child>> {
234        self.transforms.get(index)?.as_ref()
235    }
236
237    fn timed_indices(&self) -> Vec<(usize, CuTime)> {
238        (0..self.count)
239            .filter_map(|index| {
240                let transform = self.transform_at(index)?;
241                Some((index, transform.timestamp()?))
242            })
243            .collect()
244    }
245
246    #[allow(clippy::type_complexity)]
247    fn transform_pair(
248        &self,
249        first: usize,
250        second: usize,
251    ) -> Option<(
252        &TypedTransform<T, Parent, Child>,
253        &TypedTransform<T, Parent, Child>,
254    )> {
255        Some((self.transform_at(first)?, self.transform_at(second)?))
256    }
257
258    /// Sort transforms by timestamp
259    fn sort_by_time(&mut self) {
260        let mut time_indices = self.timed_indices();
261
262        // Sort by timestamp
263        time_indices.sort_by_key(|(_, time)| *time);
264
265        // Create a new ordered array
266        let mut new_transforms: [Option<TypedTransform<T, Parent, Child>>; N] =
267            std::array::from_fn(|_| None);
268
269        for (new_idx, (old_idx, _)) in time_indices.into_iter().enumerate() {
270            new_transforms[new_idx] = self.transforms[old_idx].take();
271        }
272
273        self.transforms = new_transforms;
274    }
275
276    /// Get the latest transform
277    pub fn get_latest_transform(&self) -> Option<&TypedTransform<T, Parent, Child>> {
278        self.count
279            .checked_sub(1)
280            .and_then(|index| self.transform_at(index))
281    }
282
283    /// Get transform closest to specified time
284    pub fn get_closest_transform(&self, time: CuTime) -> Option<&TypedTransform<T, Parent, Child>> {
285        if self.count == 0 {
286            return None;
287        }
288
289        let closest_idx = self
290            .timed_indices()
291            .into_iter()
292            .min_by_key(|(_, transform_time)| time.as_nanos().abs_diff(transform_time.as_nanos()))
293            .map(|(index, _)| index)
294            .unwrap_or(0);
295
296        self.transform_at(closest_idx)
297    }
298
299    /// Get time range of stored transforms
300    pub fn get_time_range(&self) -> Option<CuTimeRange> {
301        if self.count == 0 {
302            return None;
303        }
304
305        // Since we maintain sorted order, first is min, last is max
306        let end_index = self.count.checked_sub(1)?;
307        let start = self.transform_at(0)?.timestamp()?;
308        let end = self.transform_at(end_index)?.timestamp()?;
309
310        Some(CuTimeRange { start, end })
311    }
312
313    /// Get two transforms around the specified time for velocity computation
314    #[allow(clippy::type_complexity)]
315    pub fn get_transforms_around(
316        &self,
317        time: CuTime,
318    ) -> Option<(
319        &TypedTransform<T, Parent, Child>,
320        &TypedTransform<T, Parent, Child>,
321    )> {
322        if self.count < 2 {
323            return None;
324        }
325
326        // Find transforms before and after the requested time
327        let mut before_idx = None;
328        let mut after_idx = None;
329
330        for i in 0..self.count {
331            let Some(transform) = self.transform_at(i) else {
332                continue;
333            };
334            let Some(transform_time) = transform.timestamp() else {
335                continue;
336            };
337
338            if transform_time <= time {
339                before_idx = Some(i);
340            } else if after_idx.is_none() {
341                after_idx = Some(i);
342                break;
343            }
344        }
345
346        match (before_idx, after_idx) {
347            (Some(before), Some(after)) => self.transform_pair(before, after),
348            (Some(before), None) if before > 0 => self.transform_pair(before - 1, before),
349            (None, Some(after)) if after + 1 < self.count => self.transform_pair(after, after + 1),
350            _ => None,
351        }
352    }
353}
354
355impl<T, Parent, Child, const N: usize> Default for TypedTransformBuffer<T, Parent, Child, N>
356where
357    T: CuMsgPayload + Copy + Debug + 'static,
358    Parent: FrameId,
359    Child: FrameId,
360{
361    fn default() -> Self {
362        Self::new()
363    }
364}
365
366/// Velocity computation for typed transforms
367impl<T, Parent, Child> TypedTransform<T, Parent, Child>
368where
369    T: CuMsgPayload
370        + Copy
371        + Debug
372        + Default
373        + std::ops::Add<Output = T>
374        + std::ops::Sub<Output = T>
375        + std::ops::Mul<Output = T>
376        + std::ops::Div<Output = T>
377        + num_traits::NumCast
378        + 'static,
379    Parent: FrameId,
380    Child: FrameId,
381{
382    /// Compute velocity from this transform and a previous transform
383    pub fn compute_velocity(&self, previous: &Self) -> Option<VelocityTransform<T>> {
384        let current_time = self.timestamp()?;
385        let previous_time = previous.timestamp()?;
386        let current_transform = self.transform()?;
387        let previous_transform = previous.transform()?;
388
389        // Compute time difference in nanoseconds, then convert to seconds
390        let dt_nanos = current_time.as_nanos() as i64 - previous_time.as_nanos() as i64;
391        if dt_nanos <= 0 {
392            return None;
393        }
394
395        // Convert nanoseconds to seconds (1e9 nanoseconds = 1 second)
396        let dt = dt_nanos as f64 / 1_000_000_000.0;
397
398        let dt_t = num_traits::cast::cast::<f64, T>(dt)?;
399
400        // Extract positions from transforms (column-major format)
401        let current_mat = current_transform.to_matrix();
402        let previous_mat = previous_transform.to_matrix();
403        let mut linear_velocity = [T::default(); 3];
404        for (i, vel) in linear_velocity.iter_mut().enumerate() {
405            let pos_diff = current_mat[3][i] - previous_mat[3][i];
406            *vel = pos_diff / dt_t;
407        }
408
409        // Compute angular velocity (simplified version for now)
410        let angular_velocity = [T::default(); 3];
411
412        Some(VelocityTransform {
413            linear: linear_velocity,
414            angular: angular_velocity,
415        })
416    }
417}
418
#[cfg(test)]
mod tests {
    use super::*;
    use crate::frames::{RobotFrame, WorldFrame};
    use cu29::clock::CuDuration;

    type WorldToRobotFrameTransform = TypedTransform<f32, WorldFrame, RobotFrame>;
    type WorldToRobotBuffer = TypedTransformBuffer<f32, WorldFrame, RobotFrame, 10>;

    /// Asserts that `actual` lies within `epsilon` of `expected`
    /// (a stand-in for `assert_relative_eq`).
    fn assert_approx_eq(actual: f32, expected: f32, epsilon: f32) {
        let diff = (actual - expected).abs();
        assert!(
            diff <= epsilon,
            "expected {expected}, got {actual}, difference {diff} exceeds epsilon {epsilon}",
        );
    }

    /// An identity world->robot transform message stamped at `time`.
    fn identity_at(time: CuTime) -> WorldToRobotFrameTransform {
        WorldToRobotFrameTransform::new(Transform3D::<f32>::default(), time)
    }

    #[test]
    fn test_typed_transform_msg_creation() {
        let msg = identity_at(CuDuration(1000));

        assert_eq!(msg.parent_id(), WorldFrame::ID);
        assert_eq!(msg.child_id(), RobotFrame::ID);
        assert_eq!(msg.parent_name(), "world");
        assert_eq!(msg.child_name(), "robot");
        assert_eq!(msg.timestamp().unwrap().as_nanos(), 1000);
    }

    #[test]
    fn test_typed_transform_buffer() {
        let mut buffer = WorldToRobotBuffer::new();
        buffer.add_transform(identity_at(CuDuration(1000)));
        buffer.add_transform(identity_at(CuDuration(2000)));

        let latest = buffer.get_latest_transform().unwrap();
        assert_eq!(latest.timestamp().unwrap().as_nanos(), 2000);

        let range = buffer.get_time_range().unwrap();
        assert_eq!(range.start.as_nanos(), 1000);
        assert_eq!(range.end.as_nanos(), 2000);
    }

    #[test]
    fn test_closest_transform() {
        let mut buffer = WorldToRobotBuffer::new();
        buffer.add_transform(identity_at(CuDuration(1000)));
        buffer.add_transform(identity_at(CuDuration(3000)));

        let stamp_of = |query: u64| {
            buffer
                .get_closest_transform(CuDuration(query))
                .unwrap()
                .timestamp()
                .unwrap()
                .as_nanos()
        };

        assert_eq!(stamp_of(1500), 1000);
        assert_eq!(stamp_of(2500), 3000);
    }

    #[test]
    fn test_velocity_computation() {
        use crate::test_utils::translation_transform;

        // Move by (1, 2, 0) over one second: t = 1 s -> t = 2 s.
        let msg1 = WorldToRobotFrameTransform::new(
            translation_transform(0.0f32, 0.0, 0.0),
            CuDuration(1_000_000_000),
        );
        let msg2 = WorldToRobotFrameTransform::new(
            translation_transform(1.0f32, 2.0, 0.0),
            CuDuration(2_000_000_000),
        );

        let velocity = msg2.compute_velocity(&msg1).unwrap();

        for (axis, expected) in [1.0, 2.0, 0.0].into_iter().enumerate() {
            assert_approx_eq(velocity.linear[axis], expected, 1e-5);
        }
    }
}