cu_transform/
transform_payload.rs

1//! Transform message system using CuMsg and compile-time frame types
2//! This replaces the StampedTransform approach with a more Copper-native design
3
4use crate::FrameIdString;
5use crate::frames::{FrameId, FramePair};
6use crate::velocity::VelocityTransform;
7use bincode::{Decode, Encode};
8use cu_spatial_payloads::Transform3D;
9use cu29::clock::{CuTime, CuTimeRange, Tov};
10use cu29::cutask::CuStampedData;
11use cu29::prelude::CuMsgPayload;
12use num_traits;
13use serde::{Deserialize, Serialize};
14use std::fmt::Debug;
15
16/// Transforms are timestamped Relative transforms.
17pub type StampedFrameTransform<T> = CuStampedData<FrameTransform<T>, ()>;
18
19/// Transform message useable as a payload for CuStampedData.
20/// This contains just the transform data without timestamps,
21/// as timestamps are handled by CuStampedData
22///
23/// # Example
24/// ```
25/// use cu_transform::{FrameTransform, Transform3D};
26/// use cu29::prelude::*;
27/// use cu29::clock::{CuDuration, Tov};
28/// use cu_transform::transform_payload::StampedFrameTransform;
29///
30/// // Create a transform message
31/// let transform = Transform3D::<f32>::default();
32/// let payload = FrameTransform::new(
33///     transform,
34///     "world",
35///     "robot"
36/// );
37///
38/// let data = StampedFrameTransform::new(Some(payload));
39///
40/// ```
41#[derive(Clone, Debug, Serialize, Deserialize, Default)]
42pub struct FrameTransform<T: Copy + Debug + Default + Serialize + 'static> {
43    /// The actual transform
44    pub transform: Transform3D<T>,
45    /// Parent frame identifier
46    pub parent_frame: FrameIdString,
47    /// Child frame identifier
48    pub child_frame: FrameIdString,
49}
50
51impl<T: Copy + Debug + Default + Serialize + 'static> FrameTransform<T> {
52    /// Create a new transform message
53    pub fn new(
54        transform: Transform3D<T>,
55        parent_frame: impl AsRef<str>,
56        child_frame: impl AsRef<str>,
57    ) -> Self {
58        Self {
59            transform,
60            parent_frame: FrameIdString::from(parent_frame.as_ref())
61                .expect("Parent frame name too long (max 64 chars)"),
62            child_frame: FrameIdString::from(child_frame.as_ref())
63                .expect("Child frame name too long (max 64 chars)"),
64        }
65    }
66
67    /// Create from a StampedTransform (for migration)
68    pub fn from_stamped(stamped: &crate::transform::StampedTransform<T>) -> Self {
69        Self {
70            transform: stamped.transform,
71            parent_frame: FrameIdString::from(stamped.parent_frame.as_str())
72                .expect("Parent frame name too long"),
73            child_frame: FrameIdString::from(stamped.child_frame.as_str())
74                .expect("Child frame name too long"),
75        }
76    }
77}
78
79// Manual Encode/Decode implementations to work with Transform3D's specific implementations
80impl<T: Copy + Debug + Default + Serialize + 'static> Encode for FrameTransform<T>
81where
82    T: Encode,
83{
84    fn encode<E: bincode::enc::Encoder>(
85        &self,
86        encoder: &mut E,
87    ) -> Result<(), bincode::error::EncodeError> {
88        self.transform.encode(encoder)?;
89        self.parent_frame.encode(encoder)?;
90        self.child_frame.encode(encoder)?;
91        Ok(())
92    }
93}
94
95impl<T: Copy + Debug + Default + Serialize + 'static> Decode<()> for FrameTransform<T>
96where
97    T: Decode<()>,
98{
99    fn decode<D: bincode::de::Decoder<Context = ()>>(
100        decoder: &mut D,
101    ) -> Result<Self, bincode::error::DecodeError> {
102        let transform = Transform3D::decode(decoder)?;
103        let parent_frame_str = String::decode(decoder)?;
104        let child_frame_str = String::decode(decoder)?;
105        let parent_frame = FrameIdString::from(&parent_frame_str).map_err(|_| {
106            bincode::error::DecodeError::OtherString("Parent frame name too long".to_string())
107        })?;
108        let child_frame = FrameIdString::from(&child_frame_str).map_err(|_| {
109            bincode::error::DecodeError::OtherString("Child frame name too long".to_string())
110        })?;
111        Ok(Self {
112            transform,
113            parent_frame,
114            child_frame,
115        })
116    }
117}
118
119/// Transforms are timestamped Relative transforms.
120pub type TypedStampedFrameTransform<T> = CuStampedData<Transform3D<T>, ()>;
121
122/// A typed transform message that carries frame relationship information at compile time
123#[derive(Debug, Clone)]
124pub struct TypedTransform<T, Parent, Child>
125where
126    T: CuMsgPayload + Copy + Debug + 'static,
127    Parent: FrameId,
128    Child: FrameId,
129{
130    /// The actual transform message
131    pub transform: TypedStampedFrameTransform<T>,
132    /// Frame relationship (zero-sized at runtime)
133    pub frames: FramePair<Parent, Child>,
134}
135
136impl<T, Parent, Child> TypedTransform<T, Parent, Child>
137where
138    T: CuMsgPayload + Copy + Debug + 'static,
139    Parent: FrameId,
140    Child: FrameId,
141{
142    /// Create a new typed transform message
143    pub fn new(transform: Transform3D<T>, time: CuTime) -> Self {
144        let mut transform = TypedStampedFrameTransform::new(Some(transform));
145        transform.tov = Tov::Time(time);
146
147        let frames = FramePair::new();
148
149        Self { transform, frames }
150    }
151
152    /// Get the transform data
153    pub fn transform(&self) -> Option<&Transform3D<T>> {
154        self.transform.payload()
155    }
156
157    /// Get the timestamp from the message
158    pub fn timestamp(&self) -> Option<CuTime> {
159        match self.transform.tov {
160            Tov::Time(time) => Some(time),
161            _ => None,
162        }
163    }
164
165    /// Get the parent frame ID
166    pub fn parent_id(&self) -> u32 {
167        Parent::ID
168    }
169
170    /// Get the child frame ID  
171    pub fn child_id(&self) -> u32 {
172        Child::ID
173    }
174
175    /// Get the parent frame name
176    pub fn parent_name(&self) -> &'static str {
177        Parent::NAME
178    }
179
180    /// Get the child frame name
181    pub fn child_name(&self) -> &'static str {
182        Child::NAME
183    }
184}
185
186/// Fixed-size transform buffer using compile-time frame types
187/// This replaces the dynamic Vec-based approach with a fixed-size array
188#[derive(Debug)]
189pub struct TypedTransformBuffer<T, Parent, Child, const N: usize>
190where
191    T: CuMsgPayload + Copy + Debug + 'static,
192    Parent: FrameId,
193    Child: FrameId,
194{
195    /// Fixed-size array of transform messages
196    transforms: [Option<TypedTransform<T, Parent, Child>>; N],
197    /// Current number of transforms stored
198    count: usize,
199}
200
201impl<T, Parent, Child, const N: usize> TypedTransformBuffer<T, Parent, Child, N>
202where
203    T: CuMsgPayload + Copy + Debug + 'static,
204    Parent: FrameId,
205    Child: FrameId,
206{
207    /// Create a new typed transform buffer
208    pub fn new() -> Self {
209        Self {
210            transforms: std::array::from_fn(|_| None),
211            count: 0,
212        }
213    }
214
215    /// Add a transform to the buffer
216    pub fn add_transform(&mut self, transform_msg: TypedTransform<T, Parent, Child>) {
217        if self.count < N {
218            // Still have space, just add to the end
219            self.transforms[self.count] = Some(transform_msg);
220            self.count += 1;
221        } else {
222            // Buffer is full, shift everything and add to the end
223            for i in 0..N - 1 {
224                self.transforms[i] = self.transforms[i + 1].take();
225            }
226            self.transforms[N - 1] = Some(transform_msg);
227        }
228
229        // Sort to maintain time ordering
230        self.sort_by_time();
231    }
232
233    /// Sort transforms by timestamp
234    fn sort_by_time(&mut self) {
235        // Create a temporary vector of (timestamp, index) pairs
236        let mut time_indices: Vec<(CuTime, usize)> = Vec::new();
237
238        for i in 0..self.count {
239            if let Some(ref transform) = self.transforms[i]
240                && let Some(time) = transform.timestamp()
241            {
242                time_indices.push((time, i));
243            }
244        }
245
246        // Sort by timestamp
247        time_indices.sort_by_key(|(time, _)| *time);
248
249        // Create a new ordered array
250        let mut new_transforms: [Option<TypedTransform<T, Parent, Child>>; N] =
251            std::array::from_fn(|_| None);
252
253        for (idx, (_, old_idx)) in time_indices.iter().enumerate() {
254            new_transforms[idx] = self.transforms[*old_idx].take();
255        }
256
257        self.transforms = new_transforms;
258    }
259
260    /// Get the latest transform
261    pub fn get_latest_transform(&self) -> Option<&TypedTransform<T, Parent, Child>> {
262        if self.count == 0 {
263            return None;
264        }
265
266        // Since we maintain sorted order, the latest is the last one
267        self.transforms[self.count - 1].as_ref()
268    }
269
270    /// Get transform closest to specified time
271    pub fn get_closest_transform(&self, time: CuTime) -> Option<&TypedTransform<T, Parent, Child>> {
272        if self.count == 0 {
273            return None;
274        }
275
276        let mut closest_idx = 0;
277        let mut closest_diff = u64::MAX;
278
279        for i in 0..self.count {
280            if let Some(ref transform) = self.transforms[i]
281                && let Some(transform_time) = transform.timestamp()
282            {
283                let diff = if time.as_nanos() > transform_time.as_nanos() {
284                    time.as_nanos() - transform_time.as_nanos()
285                } else {
286                    transform_time.as_nanos() - time.as_nanos()
287                };
288
289                if diff < closest_diff {
290                    closest_diff = diff;
291                    closest_idx = i;
292                }
293            }
294        }
295
296        self.transforms[closest_idx].as_ref()
297    }
298
299    /// Get time range of stored transforms
300    pub fn get_time_range(&self) -> Option<CuTimeRange> {
301        if self.count == 0 {
302            return None;
303        }
304
305        // Since we maintain sorted order, first is min, last is max
306        let start = self.transforms[0].as_ref()?.timestamp()?;
307        let end = self.transforms[self.count - 1].as_ref()?.timestamp()?;
308
309        Some(CuTimeRange { start, end })
310    }
311
312    /// Get two transforms around the specified time for velocity computation
313    #[allow(clippy::type_complexity)]
314    pub fn get_transforms_around(
315        &self,
316        time: CuTime,
317    ) -> Option<(
318        &TypedTransform<T, Parent, Child>,
319        &TypedTransform<T, Parent, Child>,
320    )> {
321        if self.count < 2 {
322            return None;
323        }
324
325        // Find transforms before and after the requested time
326        let mut before_idx = None;
327        let mut after_idx = None;
328
329        for i in 0..self.count {
330            if let Some(ref transform) = self.transforms[i]
331                && let Some(transform_time) = transform.timestamp()
332            {
333                if transform_time <= time {
334                    before_idx = Some(i);
335                } else if after_idx.is_none() {
336                    after_idx = Some(i);
337                    break;
338                }
339            }
340        }
341
342        match (before_idx, after_idx) {
343            (Some(before), Some(after)) => Some((
344                self.transforms[before].as_ref()?,
345                self.transforms[after].as_ref()?,
346            )),
347            (Some(before), None) => {
348                // Time is after all our transforms, use last two
349                if before > 0 {
350                    Some((
351                        self.transforms[before - 1].as_ref()?,
352                        self.transforms[before].as_ref()?,
353                    ))
354                } else {
355                    None
356                }
357            }
358            (None, Some(after)) => {
359                // Time is before all our transforms, use first two
360                if after + 1 < self.count {
361                    Some((
362                        self.transforms[after].as_ref()?,
363                        self.transforms[after + 1].as_ref()?,
364                    ))
365                } else {
366                    None
367                }
368            }
369            _ => None,
370        }
371    }
372}
373
374impl<T, Parent, Child, const N: usize> Default for TypedTransformBuffer<T, Parent, Child, N>
375where
376    T: CuMsgPayload + Copy + Debug + 'static,
377    Parent: FrameId,
378    Child: FrameId,
379{
380    fn default() -> Self {
381        Self::new()
382    }
383}
384
385/// Velocity computation for typed transforms
386impl<T, Parent, Child> TypedTransform<T, Parent, Child>
387where
388    T: CuMsgPayload
389        + Copy
390        + Debug
391        + Default
392        + std::ops::Add<Output = T>
393        + std::ops::Sub<Output = T>
394        + std::ops::Mul<Output = T>
395        + std::ops::Div<Output = T>
396        + num_traits::NumCast
397        + 'static,
398    Parent: FrameId,
399    Child: FrameId,
400{
401    /// Compute velocity from this transform and a previous transform
402    pub fn compute_velocity(&self, previous: &Self) -> Option<VelocityTransform<T>> {
403        let current_time = self.timestamp()?;
404        let previous_time = previous.timestamp()?;
405        let current_transform = self.transform()?;
406        let previous_transform = previous.transform()?;
407
408        // Compute time difference in nanoseconds, then convert to seconds
409        let dt_nanos = current_time.as_nanos() as i64 - previous_time.as_nanos() as i64;
410        if dt_nanos <= 0 {
411            return None;
412        }
413
414        // Convert nanoseconds to seconds (1e9 nanoseconds = 1 second)
415        let dt = dt_nanos as f64 / 1_000_000_000.0;
416
417        let dt_t = num_traits::cast::cast::<f64, T>(dt)?;
418
419        // Extract positions from transforms (column-major format)
420        let current_mat = current_transform.to_matrix();
421        let previous_mat = previous_transform.to_matrix();
422        let mut linear_velocity = [T::default(); 3];
423        for (i, vel) in linear_velocity.iter_mut().enumerate() {
424            let pos_diff = current_mat[3][i] - previous_mat[3][i];
425            *vel = pos_diff / dt_t;
426        }
427
428        // Compute angular velocity (simplified version for now)
429        let angular_velocity = [T::default(); 3];
430
431        Some(VelocityTransform {
432            linear: linear_velocity,
433            angular: angular_velocity,
434        })
435    }
436}
437
#[cfg(test)]
mod tests {
    use super::*;
    use crate::frames::{RobotFrame, WorldFrame};
    use cu29::clock::CuDuration;

    /// Asserts that `actual` is within `epsilon` of `expected` (stand-in for `assert_relative_eq`).
    fn assert_approx_eq(actual: f32, expected: f32, epsilon: f32) {
        let diff = (actual - expected).abs();
        assert!(
            diff <= epsilon,
            "expected {expected}, got {actual}, difference {diff} exceeds epsilon {epsilon}",
        );
    }

    type WorldToRobotFrameTransform = TypedTransform<f32, WorldFrame, RobotFrame>;
    type WorldToRobotBuffer = TypedTransformBuffer<f32, WorldFrame, RobotFrame, 10>;

    /// Builds a world->robot message with a default transform, stamped at `nanos`.
    fn stamped_at(nanos: u64) -> WorldToRobotFrameTransform {
        WorldToRobotFrameTransform::new(Transform3D::<f32>::default(), CuDuration(nanos))
    }

    /// Timestamp in nanoseconds of a message built by `stamped_at`.
    fn nanos_of(msg: &WorldToRobotFrameTransform) -> u64 {
        msg.timestamp().unwrap().as_nanos()
    }

    #[test]
    fn test_typed_transform_msg_creation() {
        let transform = Transform3D::<f32>::default();
        let time = CuDuration(1000);

        let msg = WorldToRobotFrameTransform::new(transform, time);

        assert_eq!(msg.parent_id(), WorldFrame::ID);
        assert_eq!(msg.child_id(), RobotFrame::ID);
        assert_eq!(msg.parent_name(), "world");
        assert_eq!(msg.child_name(), "robot");
        assert_eq!(msg.timestamp().unwrap().as_nanos(), 1000);
    }

    #[test]
    fn test_typed_transform_buffer() {
        let mut buffer = WorldToRobotBuffer::new();

        let transform1 = Transform3D::<f32>::default();
        let msg1 = WorldToRobotFrameTransform::new(transform1, CuDuration(1000));

        let transform2 = Transform3D::<f32>::default();
        let msg2 = WorldToRobotFrameTransform::new(transform2, CuDuration(2000));

        buffer.add_transform(msg1);
        buffer.add_transform(msg2);

        let latest = buffer.get_latest_transform().unwrap();
        assert_eq!(latest.timestamp().unwrap().as_nanos(), 2000);

        let range = buffer.get_time_range().unwrap();
        assert_eq!(range.start.as_nanos(), 1000);
        assert_eq!(range.end.as_nanos(), 2000);
    }

    #[test]
    fn test_buffer_sorts_out_of_order_inserts() {
        let mut buffer = WorldToRobotBuffer::new();
        buffer.add_transform(stamped_at(3000));
        buffer.add_transform(stamped_at(1000));
        buffer.add_transform(stamped_at(2000));

        assert_eq!(nanos_of(buffer.get_latest_transform().unwrap()), 3000);
        let range = buffer.get_time_range().unwrap();
        assert_eq!(range.start.as_nanos(), 1000);
        assert_eq!(range.end.as_nanos(), 3000);
    }

    #[test]
    fn test_buffer_evicts_oldest_when_full() {
        let mut buffer = TypedTransformBuffer::<f32, WorldFrame, RobotFrame, 2>::new();
        for nanos in [1000, 2000, 3000] {
            buffer.add_transform(stamped_at(nanos));
        }

        let range = buffer.get_time_range().unwrap();
        assert_eq!(range.start.as_nanos(), 2000);
        assert_eq!(range.end.as_nanos(), 3000);
        assert_eq!(nanos_of(buffer.get_latest_transform().unwrap()), 3000);
    }

    #[test]
    fn test_empty_buffer_queries() {
        let buffer = WorldToRobotBuffer::default();
        assert!(buffer.get_latest_transform().is_none());
        assert!(buffer.get_closest_transform(CuDuration(1000)).is_none());
        assert!(buffer.get_time_range().is_none());
        assert!(buffer.get_transforms_around(CuDuration(1000)).is_none());
    }

    #[test]
    fn test_closest_transform() {
        let mut buffer = WorldToRobotBuffer::new();

        let transform1 = Transform3D::<f32>::default();
        let msg1 = WorldToRobotFrameTransform::new(transform1, CuDuration(1000));

        let transform2 = Transform3D::<f32>::default();
        let msg2 = WorldToRobotFrameTransform::new(transform2, CuDuration(3000));

        buffer.add_transform(msg1);
        buffer.add_transform(msg2);

        let closest = buffer.get_closest_transform(CuDuration(1500));
        assert_eq!(closest.unwrap().timestamp().unwrap().as_nanos(), 1000);

        let closest = buffer.get_closest_transform(CuDuration(2500));
        assert_eq!(closest.unwrap().timestamp().unwrap().as_nanos(), 3000);

        // Equidistant: the older entry wins.
        let closest = buffer.get_closest_transform(CuDuration(2000));
        assert_eq!(nanos_of(closest.unwrap()), 1000);

        // Past the newest entry.
        let closest = buffer.get_closest_transform(CuDuration(9000));
        assert_eq!(nanos_of(closest.unwrap()), 3000);
    }

    #[test]
    fn test_transforms_around() {
        let mut buffer = WorldToRobotBuffer::new();
        buffer.add_transform(stamped_at(1000));
        // A single entry cannot form a pair.
        assert!(buffer.get_transforms_around(CuDuration(1000)).is_none());

        buffer.add_transform(stamped_at(2000));
        buffer.add_transform(stamped_at(3000));

        let pair_nanos = |t: u64| {
            let (before, after) = buffer.get_transforms_around(CuDuration(t)).unwrap();
            (nanos_of(before), nanos_of(after))
        };
        assert_eq!(pair_nanos(1500), (1000, 2000)); // bracketed
        assert_eq!(pair_nanos(2000), (2000, 3000)); // exact hit pairs with the next entry
        assert_eq!(pair_nanos(500), (1000, 2000)); // before the range: two oldest
        assert_eq!(pair_nanos(5000), (2000, 3000)); // after the range: two newest
    }

    #[test]
    fn test_velocity_computation() {
        use crate::test_utils::translation_transform;

        let transform1 = translation_transform(0.0f32, 0.0, 0.0);
        let transform2 = translation_transform(1.0f32, 2.0, 0.0);

        let msg1 = WorldToRobotFrameTransform::new(transform1, CuDuration(1_000_000_000)); // 1 second
        let msg2 = WorldToRobotFrameTransform::new(transform2, CuDuration(2_000_000_000)); // 2 seconds

        let velocity = msg2.compute_velocity(&msg1).unwrap();

        assert_approx_eq(velocity.linear[0], 1.0, 1e-5);
        assert_approx_eq(velocity.linear[1], 2.0, 1e-5);
        assert_approx_eq(velocity.linear[2], 0.0, 1e-5);

        // A previous sample that is not strictly older yields no velocity.
        assert!(msg1.compute_velocity(&msg2).is_none());
        assert!(msg1.compute_velocity(&msg1).is_none());
    }
}