// cu_transform/tree.rs

1use crate::FrameIdString;
2use crate::error::{TransformError, TransformResult};
3use crate::transform::{StampedTransform, TransformBuffer, TransformStore};
4use crate::transform_payload::StampedFrameTransform;
5use crate::velocity::VelocityTransform;
6use crate::velocity_cache::VelocityTransformCache;
7use cu_spatial_payloads::Transform3D;
8use cu29::clock::{CuTime, RobotClock, Tov};
9use dashmap::DashMap;
10use petgraph::algo::dijkstra;
11use petgraph::graph::{DiGraph, NodeIndex};
12use serde::Serialize;
13use serde::de::DeserializeOwned;
14use std::collections::HashMap;
15use std::fmt::Debug;
16use std::ops::Neg;
17use std::time::Duration;
18
/// Abstraction over transform types that can produce their own inverse.
///
/// `T` is the scalar type of the transform. It only appears in the trait's
/// bounds, not in the method signature.
pub trait HasInverse<T>
where
    T: Copy + Debug + 'static,
{
    /// Returns the inverse of `self`.
    fn inverse(&self) -> Self;
}
23
24impl HasInverse<f32> for Transform3D<f32> {
25    fn inverse(&self) -> Self {
26        self.inverse()
27    }
28}
29
30impl HasInverse<f64> for Transform3D<f64> {
31    fn inverse(&self) -> Self {
32        self.inverse()
33    }
34}
35
36/// The cache entry for a transform query
37#[derive(Clone)]
38struct TransformCacheEntry<T: Copy + Debug + 'static> {
39    /// The cached transform result
40    transform: Transform3D<T>,
41    /// The time for which this transform was calculated
42    time: CuTime,
43    /// When this cache entry was last accessed
44    last_access: CuTime,
45    /// Path used to calculate this transform
46    path_hash: u64,
47}
48
49/// A cache for transforms to avoid recalculating frequently accessed paths
50struct TransformCache<T: Copy + Debug + 'static> {
51    /// Map from (source, target) frames to cached transforms
52    entries: DashMap<(FrameIdString, FrameIdString), TransformCacheEntry<T>>,
53    /// Maximum size of the cache
54    max_size: usize,
55    /// Maximum age of cache entries before invalidation (in nanoseconds)
56    max_age_nanos: u64,
57    /// Last time the cache was cleaned up (needs to be mutable)
58    last_cleanup_cell: std::sync::Mutex<CuTime>,
59    /// Cleanup interval (in nanoseconds)
60    cleanup_interval_nanos: u64,
61}
62
63impl<T: Copy + Debug + 'static> TransformCache<T> {
64    fn new(max_size: usize, max_age: Duration) -> Self {
65        Self {
66            entries: DashMap::with_capacity(max_size),
67            max_size,
68            max_age_nanos: max_age.as_nanos() as u64,
69            last_cleanup_cell: std::sync::Mutex::new(CuTime::from(0u64)),
70            cleanup_interval_nanos: 5_000_000_000, // Clean every 5 seconds
71        }
72    }
73
74    /// Get a cached transform if it exists and is still valid
75    fn get(
76        &self,
77        from: &str,
78        to: &str,
79        time: CuTime,
80        path_hash: u64,
81        robot_clock: &RobotClock,
82    ) -> Option<Transform3D<T>> {
83        let key = (
84            FrameIdString::from(from).expect("Frame name too long"),
85            FrameIdString::from(to).expect("Frame name too long"),
86        );
87
88        if let Some(mut entry) = self.entries.get_mut(&key) {
89            let now = robot_clock.now();
90
91            // Check if the cache entry is for the same time and path
92            if entry.time == time && entry.path_hash == path_hash {
93                // Check if the entry is still valid (not too old)
94                let age = now.as_nanos().saturating_sub(entry.last_access.as_nanos());
95                if age <= self.max_age_nanos {
96                    // Update last access time
97                    entry.last_access = now;
98                    return Some(entry.transform);
99                }
100            }
101        }
102
103        None
104    }
105
106    /// Add a new transform to the cache
107    fn insert(
108        &self,
109        from: &str,
110        to: &str,
111        transform: Transform3D<T>,
112        time: CuTime,
113        path_hash: u64,
114        robot_clock: &RobotClock,
115    ) {
116        let now = robot_clock.now();
117        let key = (
118            FrameIdString::from(from).expect("Frame name too long"),
119            FrameIdString::from(to).expect("Frame name too long"),
120        );
121
122        // If the cache is at capacity, remove the oldest entry
123        if self.entries.len() >= self.max_size {
124            let oldest_key = self
125                .entries
126                .iter()
127                .min_by_key(|entry| entry.last_access)
128                .map(|entry| *entry.key());
129            if let Some(key_to_remove) = oldest_key {
130                self.entries.remove(&key_to_remove);
131            }
132        }
133
134        // Insert the new entry
135        self.entries.insert(
136            key,
137            TransformCacheEntry {
138                transform,
139                time,
140                last_access: now,
141                path_hash,
142            },
143        );
144    }
145
146    /// Check if it's time to clean up the cache
147    fn should_cleanup(&self, robot_clock: &RobotClock) -> bool {
148        let now = robot_clock.now();
149        let last_cleanup = *self.last_cleanup_cell.lock().unwrap();
150        let elapsed = now.as_nanos().saturating_sub(last_cleanup.as_nanos());
151        elapsed >= self.cleanup_interval_nanos
152    }
153
154    /// Clear old entries from the cache
155    fn cleanup(&self, robot_clock: &RobotClock) {
156        let now = robot_clock.now();
157        let mut keys_to_remove = Vec::new();
158
159        // Identify keys to remove
160        for entry in self.entries.iter() {
161            let age = now.as_nanos().saturating_sub(entry.last_access.as_nanos());
162            if age > self.max_age_nanos {
163                keys_to_remove.push(*entry.key());
164            }
165        }
166
167        // Remove expired entries
168        for key in keys_to_remove {
169            self.entries.remove(&key);
170        }
171
172        // Update last cleanup time
173        *self.last_cleanup_cell.lock().unwrap() = now;
174    }
175
176    /// Clear all entries
177    fn clear(&self) {
178        self.entries.clear();
179    }
180}
181
182pub struct TransformTree<T: Copy + Debug + Default + 'static> {
183    graph: DiGraph<FrameIdString, ()>,
184    frame_indices: HashMap<FrameIdString, NodeIndex>,
185    transform_store: TransformStore<T>,
186    // Concurrent cache for transform lookups
187    cache: TransformCache<T>,
188    // Concurrent cache for velocity transform lookups
189    velocity_cache: VelocityTransformCache<T>,
190}
191
/// Supplies the value "one" for a scalar type.
///
/// Used to build identity matrices generically (see `create_identity_transform`).
pub trait One {
    /// Returns `1` expressed in `Self`.
    fn one() -> Self;
}

/// Implements [`One`] for each listed type by returning the given literal.
macro_rules! impl_one {
    ($($ty:ty => $value:expr),+ $(,)?) => {
        $(
            impl One for $ty {
                fn one() -> Self {
                    $value
                }
            }
        )+
    };
}

// Common numeric types.
impl_one! {
    f32 => 1.0,
    f64 => 1.0,
    i32 => 1,
    i64 => 1,
    u32 => 1,
    u64 => 1,
}
234
235// We need to limit T to types where Transform3D<T> has Clone and inverse method
236// and now we also require T to implement One
237impl<T: Copy + Debug + Default + One + Serialize + DeserializeOwned + 'static + Neg<Output = T>>
238    TransformTree<T>
239where
240    Transform3D<T>: Clone + HasInverse<T> + std::ops::Mul<Output = Transform3D<T>>,
241    T: std::ops::Add<Output = T>
242        + std::ops::Sub<Output = T>
243        + std::ops::Mul<Output = T>
244        + std::ops::Div<Output = T>
245        + std::ops::AddAssign
246        + std::ops::SubAssign
247        + num_traits::NumCast,
248{
249    /// Default cache size (number of transforms to cache)
250    const DEFAULT_CACHE_SIZE: usize = 100;
251
252    /// Default cache entry lifetime (5 seconds)
253    const DEFAULT_CACHE_AGE: Duration = Duration::from_secs(5);
254
255    /// Create a new transform tree with default settings
256    pub fn new() -> Self {
257        Self {
258            graph: DiGraph::new(),
259            frame_indices: HashMap::new(),
260            transform_store: TransformStore::new(),
261            cache: TransformCache::new(Self::DEFAULT_CACHE_SIZE, Self::DEFAULT_CACHE_AGE),
262            velocity_cache: VelocityTransformCache::new(
263                Self::DEFAULT_CACHE_SIZE,
264                Self::DEFAULT_CACHE_AGE.as_nanos() as u64,
265            ),
266        }
267    }
268
269    /// Create a new transform tree with custom cache settings
270    pub fn with_cache_settings(cache_size: usize, cache_age: Duration) -> Self {
271        Self {
272            graph: DiGraph::new(),
273            frame_indices: HashMap::new(),
274            transform_store: TransformStore::new(),
275            cache: TransformCache::new(cache_size, cache_age),
276            velocity_cache: VelocityTransformCache::new(cache_size, cache_age.as_nanos() as u64),
277        }
278    }
279
280    /// Clear the transform cache
281    pub fn clear_cache(&self) {
282        self.cache.clear();
283        self.velocity_cache.clear();
284    }
285
286    /// Perform scheduled cache cleanup operation
287    pub fn cleanup_cache(&self, robot_clock: &RobotClock) {
288        self.cache.cleanup(robot_clock);
289        self.velocity_cache.cleanup(robot_clock);
290    }
291
292    /// Creates an identity transform matrix in a type-safe way
293    fn create_identity_transform() -> Transform3D<T> {
294        let mut mat = [[T::default(); 4]; 4];
295
296        // Set the diagonal elements to one
297        let one = T::one();
298        mat[0][0] = one;
299        mat[1][1] = one;
300        mat[2][2] = one;
301        mat[3][3] = one;
302
303        Transform3D::from_matrix(mat)
304    }
305
306    fn ensure_frame(&mut self, frame_id: &str) -> NodeIndex {
307        let frame_id_string = FrameIdString::from(frame_id).expect("Frame name too long");
308        *self
309            .frame_indices
310            .entry(frame_id_string)
311            .or_insert_with(|| self.graph.add_node(frame_id_string))
312    }
313
314    fn get_segment_buffer(
315        &self,
316        parent: &FrameIdString,
317        child: &FrameIdString,
318    ) -> TransformResult<TransformBuffer<T>> {
319        self.transform_store
320            .get_buffer(parent, child)
321            .ok_or_else(|| TransformError::TransformNotFound {
322                from: parent.to_string(),
323                to: child.to_string(),
324            })
325    }
326
327    /// add a transform to the tree.
328    pub fn add_transform(&mut self, sft: &StampedFrameTransform<T>) -> TransformResult<()>
329    where
330        T: bincode::Encode + bincode::Decode<()>,
331    {
332        let transform_msg = sft.payload().ok_or_else(|| {
333            TransformError::Unknown("Failed to get transform payload".to_string())
334        })?;
335
336        let timestamp = match sft.tov {
337            Tov::Time(time) => time,
338            Tov::Range(range) => range.start, // Use start of range
339            _ => {
340                return Err(TransformError::Unknown(
341                    "Invalid Time of Validity".to_string(),
342                ));
343            }
344        };
345
346        // Ensure frames exist in the graph
347        let parent_idx = self.ensure_frame(&transform_msg.parent_frame);
348        let child_idx = self.ensure_frame(&transform_msg.child_frame);
349
350        // Check for cycles
351        if self.would_create_cycle(parent_idx, child_idx) {
352            return Err(TransformError::CyclicTransformTree);
353        }
354
355        // Add edge if it doesn't exist
356        if !self.graph.contains_edge(parent_idx, child_idx) {
357            self.graph.add_edge(parent_idx, child_idx, ());
358        }
359
360        // Clear velocity cache since we're adding a transform that could change velocities
361        self.velocity_cache.clear();
362
363        // Create StampedTransform for the store (internal implementation detail)
364        let stamped = StampedTransform {
365            transform: transform_msg.transform,
366            stamp: timestamp,
367            parent_frame: transform_msg.parent_frame,
368            child_frame: transform_msg.child_frame,
369        };
370
371        // Add transform to the store
372        self.transform_store.add_transform(stamped);
373        Ok(())
374    }
375
376    fn would_create_cycle(&self, parent: NodeIndex, child: NodeIndex) -> bool {
377        if self.graph.contains_edge(parent, child) {
378            return false;
379        }
380
381        matches!(dijkstra(&self.graph, child, Some(parent), |_| 1), result if result.contains_key(&parent))
382    }
383
384    pub fn find_path(
385        &self,
386        from_frame: &str,
387        to_frame: &str,
388    ) -> TransformResult<Vec<(FrameIdString, FrameIdString, bool)>> {
389        // If frames are the same, return empty path (identity transform)
390        if from_frame == to_frame {
391            return Ok(Vec::new());
392        }
393
394        let from_frame_id = FrameIdString::from(from_frame).expect("Frame name too long");
395        let from_idx = self
396            .frame_indices
397            .get(&from_frame_id)
398            .ok_or(TransformError::FrameNotFound(from_frame.to_string()))?;
399
400        let to_frame_id = FrameIdString::from(to_frame).expect("Frame name too long");
401        let to_idx = self
402            .frame_indices
403            .get(&to_frame_id)
404            .ok_or(TransformError::FrameNotFound(to_frame.to_string()))?;
405
406        // Create an undirected version of the graph to find any path (forward or inverse)
407        let mut undirected_graph = self.graph.clone();
408
409        // Add reverse edges for every existing edge to make it undirected
410        let edges: Vec<_> = self.graph.edge_indices().collect();
411        for edge_idx in edges {
412            let (a, b) = self.graph.edge_endpoints(edge_idx).unwrap();
413            if !undirected_graph.contains_edge(b, a) {
414                undirected_graph.add_edge(b, a, ());
415            }
416        }
417
418        // Now find path in undirected graph
419        let path = dijkstra(&undirected_graph, *from_idx, Some(*to_idx), |_| 1);
420
421        if !path.contains_key(to_idx) {
422            return Err(TransformError::TransformNotFound {
423                from: from_frame.to_string(),
424                to: to_frame.to_string(),
425            });
426        }
427
428        // Reconstruct the path
429        let mut current = *to_idx;
430        let mut path_nodes = vec![current];
431
432        while current != *from_idx {
433            let mut found_next = false;
434
435            // Try all neighbors in undirected graph
436            for neighbor in undirected_graph.neighbors(current) {
437                if path.contains_key(&neighbor) && path[&neighbor] < path[&current] {
438                    current = neighbor;
439                    path_nodes.push(current);
440                    found_next = true;
441                    break;
442                }
443            }
444
445            if !found_next {
446                return Err(TransformError::TransformNotFound {
447                    from: from_frame.to_string(),
448                    to: to_frame.to_string(),
449                });
450            }
451        }
452
453        path_nodes.reverse();
454
455        // Convert node path to edge path with direction information
456        let mut path_edges = Vec::new();
457        for i in 0..path_nodes.len() - 1 {
458            let parent_idx = path_nodes[i];
459            let child_idx = path_nodes[i + 1];
460
461            let parent_frame = self.graph[parent_idx];
462            let child_frame = self.graph[child_idx];
463
464            // Check if this is a forward edge in original directed graph
465            let is_forward = self.graph.contains_edge(parent_idx, child_idx);
466
467            if is_forward {
468                // Forward edge: parent -> child
469                path_edges.push((parent_frame, child_frame, false));
470            } else {
471                // Inverse edge: child <- parent (we need child -> parent)
472                path_edges.push((child_frame, parent_frame, true));
473            }
474        }
475
476        Ok(path_edges)
477    }
478
479    /// Compute a simple hash value for a path to use as a cache key
480    fn compute_path_hash(path: &[(FrameIdString, FrameIdString, bool)]) -> u64 {
481        use std::collections::hash_map::DefaultHasher;
482        use std::hash::{Hash, Hasher};
483
484        let mut hasher = DefaultHasher::new();
485
486        for (parent, child, inverse) in path {
487            parent.hash(&mut hasher);
488            child.hash(&mut hasher);
489            inverse.hash(&mut hasher);
490        }
491
492        hasher.finish()
493    }
494
495    pub fn lookup_transform(
496        &self,
497        from_frame: &str,
498        to_frame: &str,
499        time: CuTime,
500        robot_clock: &RobotClock,
501    ) -> TransformResult<Transform3D<T>> {
502        // Identity case: same frame
503        if from_frame == to_frame {
504            return Ok(Self::create_identity_transform());
505        }
506
507        // Find the path between frames
508        let path = self.find_path(from_frame, to_frame)?;
509
510        if path.is_empty() {
511            // Empty path is another case for identity transform
512            return Ok(Self::create_identity_transform());
513        }
514
515        // Calculate a hash for the path (for cache lookups)
516        let path_hash = Self::compute_path_hash(&path);
517
518        // Try to get the transform from cache - concurrent map allows lock-free reads
519        if let Some(cached_transform) =
520            self.cache
521                .get(from_frame, to_frame, time, path_hash, robot_clock)
522        {
523            return Ok(cached_transform);
524        }
525
526        // Check if it's time to clean up the cache
527        if self.cache.should_cleanup(robot_clock) {
528            self.cache.cleanup(robot_clock);
529        }
530
531        // Cache miss - compute the transform
532
533        // Compose multiple transforms along the path
534        let mut result = Self::create_identity_transform();
535
536        // Iterate through each segment of the path
537        for (parent, child, inverse) in &path {
538            let buffer = self.get_segment_buffer(parent, child)?;
539
540            let transform = buffer
541                .get_closest_transform(time)
542                .ok_or(TransformError::TransformTimeNotAvailable(time))?;
543
544            // Note: In transform composition, the right-most transform is applied first.
545            let transform_to_apply = if *inverse {
546                transform.transform.inverse()
547            } else {
548                transform.transform
549            };
550            result = transform_to_apply * result;
551        }
552
553        // Cache the computed result
554        self.cache
555            .insert(from_frame, to_frame, result, time, path_hash, robot_clock);
556
557        Ok(result)
558    }
559
560    /// Look up the velocity of a frame at a specific time
561    ///
562    /// This computes the velocity by differentiating transforms over time.
563    /// Returns the velocity expressed in the target frame.
564    ///
565    /// Results are automatically cached for improved performance. The cache is
566    /// invalidated when new transforms are added or when cache entries expire based
567    /// on their age. The cache significantly improves performance for repeated lookups
568    /// of the same frames and times.
569    ///
570    /// # Arguments
571    /// * `from_frame` - The source frame
572    /// * `to_frame` - The target frame
573    /// * `time` - The time at which to compute the velocity
574    ///
575    /// # Returns
576    /// * A VelocityTransform containing linear and angular velocity components
577    /// * Error if the transform is not available or cannot be computed
578    ///
579    /// # Performance
580    /// The first lookup of a specific frame pair and time will compute the velocity and
581    /// cache the result. Subsequent lookups will use the cached result, which is much faster.
582    /// For real-time or performance-critical applications, this caching is crucial.
583    ///
584    /// # Cache Management
585    /// The cache is automatically cleared when new transforms are added. You can also
586    /// manually clear the cache with `clear_cache()` or trigger cleanup with `cleanup_cache()`.
587    pub fn lookup_velocity(
588        &self,
589        from_frame: &str,
590        to_frame: &str,
591        time: CuTime,
592        robot_clock: &RobotClock,
593    ) -> TransformResult<VelocityTransform<T>> {
594        // Identity case: same frame
595        if from_frame == to_frame {
596            return Ok(VelocityTransform::default());
597        }
598
599        // Find the path between frames
600        let path = self.find_path(from_frame, to_frame)?;
601
602        if path.is_empty() {
603            // Empty path means identity transform (zero velocity)
604            return Ok(VelocityTransform::default());
605        }
606
607        // Calculate a hash for the path (for cache lookups)
608        let path_hash = Self::compute_path_hash(&path);
609
610        // Try to get the velocity from cache
611        if let Some(cached_velocity) =
612            self.velocity_cache
613                .get(from_frame, to_frame, time, path_hash, robot_clock)
614        {
615            return Ok(cached_velocity);
616        }
617
618        // Check if it's time to clean up the cache
619        if self.velocity_cache.should_cleanup(robot_clock) {
620            self.velocity_cache.cleanup(robot_clock);
621        }
622
623        // Cache miss - compute the velocity
624
625        // Initialize zero velocity
626        let mut result = VelocityTransform::default();
627
628        // Iterate through each segment of the path
629        for (parent, child, inverse) in &path {
630            let buffer = self.get_segment_buffer(parent, child)?;
631
632            // Compute velocity for this segment
633            let segment_velocity = buffer
634                .compute_velocity_at_time(time)
635                .ok_or(TransformError::TransformTimeNotAvailable(time))?;
636
637            // Get the transform at the requested time
638            let transform = buffer
639                .get_closest_transform(time)
640                .ok_or(TransformError::TransformTimeNotAvailable(time))?;
641
642            // Apply the proper velocity transformation
643            // We need the current position for proper velocity transformation
644            let position = [T::default(); 3]; // Assume transformation at origin for simplicity
645            // A more accurate implementation would track the position
646
647            let transformed_velocity = if *inverse {
648                let inverse_transform = transform.transform.inverse();
649                crate::velocity::transform_velocity(
650                    &segment_velocity.negate(),
651                    &inverse_transform,
652                    &position,
653                )
654            } else {
655                crate::velocity::transform_velocity(
656                    &segment_velocity,
657                    &transform.transform,
658                    &position,
659                )
660            };
661
662            result.linear[0] += transformed_velocity.linear[0];
663            result.linear[1] += transformed_velocity.linear[1];
664            result.linear[2] += transformed_velocity.linear[2];
665
666            result.angular[0] += transformed_velocity.angular[0];
667            result.angular[1] += transformed_velocity.angular[1];
668            result.angular[2] += transformed_velocity.angular[2];
669        }
670
671        // Cache the computed result
672        self.velocity_cache.insert(
673            from_frame,
674            to_frame,
675            result.clone(),
676            time,
677            path_hash,
678            robot_clock,
679        );
680
681        Ok(result)
682    }
683}
684
685impl<T: Copy + Debug + Default + One + Serialize + DeserializeOwned + 'static + Neg<Output = T>>
686    Default for TransformTree<T>
687where
688    Transform3D<T>: Clone + HasInverse<T> + std::ops::Mul<Output = Transform3D<T>>,
689    T: std::ops::Add<Output = T>
690        + std::ops::Sub<Output = T>
691        + std::ops::Mul<Output = T>
692        + std::ops::Div<Output = T>
693        + std::ops::AddAssign
694        + std::ops::SubAssign
695        + num_traits::NumCast,
696{
697    fn default() -> Self {
698        Self::new()
699    }
700}
701
702#[cfg(test)]
703#[allow(deprecated)] // We intentionally test deprecated APIs for backward compatibility
704mod tests {
705    use super::*;
706    use crate::test_utils::get_translation;
707    use crate::{FrameTransform, frame_id};
708    use cu29::clock::{CuDuration, RobotClock};
709
710    // Helper function to replace assert_relative_eq
711    fn assert_approx_eq(actual: f32, expected: f32, epsilon: f32, message: &str) {
712        let diff = (actual - expected).abs();
713        assert!(
714            diff <= epsilon,
715            "{message}: expected {expected}, got {actual}, difference {diff} exceeds epsilon {epsilon}",
716        );
717    }
718
719    fn make_stamped(
720        parent: &str,
721        child: &str,
722        ts: CuDuration,
723        tf: Transform3D<f32>,
724    ) -> StampedFrameTransform<f32> {
725        let inner = FrameTransform {
726            transform: tf,
727            parent_frame: frame_id!(parent),
728            child_frame: frame_id!(child),
729        };
730        let mut stf = StampedFrameTransform::new(Some(inner));
731        stf.tov = ts.into();
732        stf
733    }
734
735    // Only use f32/f64 for tests since our inverse transform is only implemented for these types
736
737    #[test]
738    fn test_add_transform() {
739        let mut tree = TransformTree::<f32>::new();
740
741        let inner = FrameTransform {
742            transform: Transform3D::default(),
743            parent_frame: frame_id!("world"),
744            child_frame: frame_id!("robot"),
745        };
746        let mut stf = StampedFrameTransform::new(Some(inner));
747        stf.tov = CuDuration(1000).into();
748
749        assert!(tree.add_transform(&stf).is_ok());
750    }
751
752    #[test]
753    fn test_cyclic_transforms() {
754        let mut tree = TransformTree::<f32>::new();
755
756        let transform1 = make_stamped("world", "robot", 1000.into(), Transform3D::default());
757        let transform2 = make_stamped("robot", "sensor", 1000.into(), Transform3D::default());
758        let transform3 = make_stamped("sensor", "world", 1000.into(), Transform3D::default());
759
760        assert!(tree.add_transform(&transform1).is_ok());
761        assert!(tree.add_transform(&transform2).is_ok());
762
763        let result = tree.add_transform(&transform3);
764        assert!(result.is_err());
765        if let Err(e) = result {
766            assert!(matches!(e, TransformError::CyclicTransformTree));
767        }
768    }
769
770    #[test]
771    fn test_find_path() {
772        let mut tree = TransformTree::<f32>::new();
773
774        let transform1 = make_stamped("world", "robot", 1000.into(), Transform3D::default());
775        let transform2 = make_stamped("robot", "sensor", 1000.into(), Transform3D::default());
776
777        assert!(tree.add_transform(&transform1).is_ok());
778        assert!(tree.add_transform(&transform2).is_ok());
779
780        let path = tree.find_path("world", "sensor");
781        assert!(path.is_ok());
782
783        let path_vec = path.unwrap();
784        assert_eq!(path_vec.len(), 2);
785        assert_eq!(path_vec[0].0.as_str(), "world");
786        assert_eq!(path_vec[0].1.as_str(), "robot");
787        assert_eq!(path_vec[1].0.as_str(), "robot");
788        assert_eq!(path_vec[1].1.as_str(), "sensor");
789    }
790
791    #[test]
792    fn test_lookup_transform_with_inverse() {
793        let mut tree = TransformTree::<f32>::new();
794
795        let matrix = [
796            [1.0, 0.0, 0.0, 0.0],
797            [0.0, 1.0, 0.0, 0.0],
798            [0.0, 0.0, 1.0, 0.0],
799            [2.0, 3.0, 4.0, 1.0],
800        ];
801        let tf = make_stamped(
802            "world",
803            "robot",
804            CuDuration(1000),
805            Transform3D::from_matrix(matrix),
806        );
807
808        assert!(tree.add_transform(&tf).is_ok());
809
810        let clock = RobotClock::default();
811
812        let forward = tree
813            .lookup_transform("world", "robot", CuDuration(1000), &clock)
814            .unwrap();
815        assert_eq!(get_translation(&forward).0, 2.0);
816        assert_eq!(get_translation(&forward).1, 3.0);
817        assert_eq!(get_translation(&forward).2, 4.0);
818
819        let inverse = tree
820            .lookup_transform("robot", "world", CuDuration(1000), &clock)
821            .unwrap();
822        assert_eq!(get_translation(&inverse).0, -2.0);
823        assert_eq!(get_translation(&inverse).1, -3.0);
824        assert_eq!(get_translation(&inverse).2, -4.0);
825    }
826
827    #[test]
828    fn test_multi_step_transform_composition() {
829        let mut tree = TransformTree::<f32>::new();
830        let ts = CuDuration(1000);
831
832        let world_to_base = make_stamped(
833            "world",
834            "base",
835            ts,
836            Transform3D::from_matrix([
837                [1.0, 0.0, 0.0, 0.0],
838                [0.0, 1.0, 0.0, 0.0],
839                [0.0, 0.0, 1.0, 0.0],
840                [1.0, 0.0, 0.0, 1.0],
841            ]),
842        );
843
844        let base_to_arm = make_stamped(
845            "base",
846            "arm",
847            ts,
848            Transform3D::from_matrix([
849                [0.0, 1.0, 0.0, 0.0],
850                [-1.0, 0.0, 0.0, 0.0],
851                [0.0, 0.0, 1.0, 0.0],
852                [0.0, 0.0, 0.0, 1.0],
853            ]),
854        );
855
856        let arm_to_gripper = make_stamped(
857            "arm",
858            "gripper",
859            ts,
860            Transform3D::from_matrix([
861                [1.0, 0.0, 0.0, 0.0],
862                [0.0, 1.0, 0.0, 0.0],
863                [0.0, 0.0, 1.0, 0.0],
864                [0.0, 2.0, 0.0, 1.0],
865            ]),
866        );
867
868        assert!(tree.add_transform(&world_to_base).is_ok());
869        assert!(tree.add_transform(&base_to_arm).is_ok());
870        assert!(tree.add_transform(&arm_to_gripper).is_ok());
871
872        let clock = RobotClock::default();
873        let transform = tree
874            .lookup_transform("world", "gripper", ts, &clock)
875            .unwrap();
876        let epsilon = 1e-5;
877
878        let mat = transform.to_matrix();
879        assert_approx_eq(mat[0][0], 0.0, epsilon, "mat_0_0");
880        assert_approx_eq(mat[1][0], -1.0, epsilon, "mat_1_0");
881        assert_approx_eq(mat[0][1], 1.0, epsilon, "mat_0_1");
882        assert_approx_eq(mat[1][1], 0.0, epsilon, "mat_1_1");
883
884        assert_approx_eq(get_translation(&transform).0, 0.0, epsilon, "translation_x");
885        assert_approx_eq(get_translation(&transform).1, 3.0, epsilon, "translation_y");
886        assert_approx_eq(get_translation(&transform).2, 0.0, epsilon, "translation_z");
887
888        let cached = tree
889            .lookup_transform("world", "gripper", ts, &clock)
890            .unwrap();
891        for i in 0..4 {
892            for j in 0..4 {
893                assert_approx_eq(
894                    transform.to_matrix()[i][j],
895                    cached.to_matrix()[i][j],
896                    epsilon,
897                    "matrix_element",
898                );
899            }
900        }
901
902        let inverse = tree
903            .lookup_transform("gripper", "world", ts, &clock)
904            .unwrap();
905        let inv_mat = inverse.to_matrix();
906        assert_approx_eq(inv_mat[1][0], 1.0, epsilon, "inv_mat_1_0");
907        assert_approx_eq(inv_mat[0][1], -1.0, epsilon, "inv_mat_0_1");
908        assert_approx_eq(get_translation(&inverse).0, -3.0, epsilon, "translation_x");
909        assert_approx_eq(get_translation(&inverse).1, 0.0, epsilon, "translation_y");
910
911        let product = transform * inverse;
912        for i in 0..4 {
913            for j in 0..4 {
914                let expected = if i == j { 1.0 } else { 0.0 };
915                assert_approx_eq(
916                    product.to_matrix()[i][j],
917                    expected,
918                    epsilon,
919                    "matrix_element",
920                );
921            }
922        }
923    }
924
925    #[test]
926    fn test_cache_invalidation() {
927        let mut tree = TransformTree::<f32>::with_cache_settings(5, Duration::from_millis(50));
928        let ts = CuDuration(1000);
929
930        let tf = make_stamped(
931            "a",
932            "b",
933            ts,
934            Transform3D::from_matrix([
935                [1.0, 0.0, 0.0, 0.0],
936                [0.0, 1.0, 0.0, 0.0],
937                [0.0, 0.0, 1.0, 0.0],
938                [1.0, 2.0, 3.0, 1.0],
939            ]),
940        );
941
942        assert!(tree.add_transform(&tf).is_ok());
943
944        let clock = RobotClock::default();
945        let result1 = tree.lookup_transform("a", "b", ts, &clock);
946        assert!(result1.is_ok());
947
948        let transform1 = result1.unwrap();
949        assert_eq!(get_translation(&transform1).0, 1.0);
950
951        std::thread::sleep(Duration::from_millis(100));
952
953        let result2 = tree.lookup_transform("a", "b", ts, &clock);
954        assert!(result2.is_ok());
955
956        tree.clear_cache();
957
958        let result3 = tree.lookup_transform("a", "b", ts, &clock);
959        assert!(result3.is_ok());
960    }
961
962    #[test]
963    fn test_multi_step_transform_with_inverse() {
964        let mut tree = TransformTree::<f32>::new();
965        let ts = CuDuration(1000);
966
967        let world_to_robot = make_stamped(
968            "world",
969            "robot",
970            ts,
971            Transform3D::from_matrix([
972                [1.0, 0.0, 0.0, 0.0],
973                [0.0, 1.0, 0.0, 0.0],
974                [0.0, 0.0, 1.0, 0.0],
975                [1.0, 2.0, 3.0, 1.0],
976            ]),
977        );
978
979        let robot_to_camera = make_stamped(
980            "robot",
981            "camera",
982            ts,
983            Transform3D::from_matrix([
984                [0.0, 1.0, 0.0, 0.0],
985                [-1.0, 0.0, 0.0, 0.0],
986                [0.0, 0.0, 1.0, 0.0],
987                [0.5, 0.0, 0.2, 1.0],
988            ]),
989        );
990
991        assert!(tree.add_transform(&world_to_robot).is_ok());
992        assert!(tree.add_transform(&robot_to_camera).is_ok());
993
994        let clock = RobotClock::default();
995        let transform = tree
996            .lookup_transform("world", "camera", ts, &clock)
997            .unwrap();
998        let epsilon = 1e-5;
999
1000        let mat = transform.to_matrix();
1001        assert_approx_eq(mat[0][0], 0.0, epsilon, "mat_0_0");
1002        assert_approx_eq(mat[1][0], -1.0, epsilon, "mat_1_0");
1003        assert_approx_eq(mat[0][1], 1.0, epsilon, "mat_0_1");
1004        assert_approx_eq(mat[1][1], 0.0, epsilon, "mat_1_1");
1005
1006        assert_approx_eq(
1007            get_translation(&transform).0,
1008            -1.5,
1009            epsilon,
1010            "translation_x",
1011        );
1012        assert_approx_eq(get_translation(&transform).1, 1.0, epsilon, "translation_y");
1013        assert_approx_eq(get_translation(&transform).2, 3.2, epsilon, "translation_z");
1014
1015        let inverse = tree.lookup_transform("camera", "world", ts, &clock);
1016        assert!(inverse.is_ok());
1017    }
1018
1019    #[test]
1020    fn test_cache_cleanup() {
1021        let tree = TransformTree::<f32>::with_cache_settings(5, Duration::from_millis(10));
1022
1023        // Explicitly trigger cache cleanup
1024        let clock = RobotClock::default();
1025        tree.cleanup_cache(&clock);
1026    }
1027
1028    #[test]
1029    fn test_lookup_velocity() {
1030        let mut tree = TransformTree::<f32>::new();
1031
1032        let w2b_1 = make_stamped(
1033            "world",
1034            "base",
1035            CuDuration(1_000_000_000),
1036            Transform3D::from_matrix([
1037                [1.0, 0.0, 0.0, 0.0],
1038                [0.0, 1.0, 0.0, 0.0],
1039                [0.0, 0.0, 1.0, 0.0],
1040                [0.0, 0.0, 0.0, 1.0],
1041            ]),
1042        );
1043
1044        let w2b_2 = make_stamped(
1045            "world",
1046            "base",
1047            CuDuration(2_000_000_000),
1048            Transform3D::from_matrix([
1049                [1.0, 0.0, 0.0, 0.0],
1050                [0.0, 1.0, 0.0, 0.0],
1051                [0.0, 0.0, 1.0, 0.0],
1052                [1.0, 0.0, 0.0, 1.0],
1053            ]),
1054        );
1055
1056        let b2s_1 = make_stamped(
1057            "base",
1058            "sensor",
1059            CuDuration(1_000_000_000),
1060            Transform3D::from_matrix([
1061                [1.0, 0.0, 0.0, 0.0],
1062                [0.0, 1.0, 0.0, 0.0],
1063                [0.0, 0.0, 1.0, 0.0],
1064                [0.0, 0.0, 0.0, 1.0],
1065            ]),
1066        );
1067
1068        let b2s_2 = make_stamped(
1069            "base",
1070            "sensor",
1071            CuDuration(2_000_000_000),
1072            Transform3D::from_matrix([
1073                [1.0, 0.0, 0.0, 0.0],
1074                [0.0, 1.0, 0.0, 0.0],
1075                [0.0, 0.0, 1.0, 0.0],
1076                [0.0, 2.0, 0.0, 1.0],
1077            ]),
1078        );
1079
1080        tree.add_transform(&w2b_1).unwrap();
1081        tree.add_transform(&w2b_2).unwrap();
1082        tree.add_transform(&b2s_1).unwrap();
1083        tree.add_transform(&b2s_2).unwrap();
1084
1085        let clock = RobotClock::default();
1086        let velocity = tree.lookup_velocity("world", "sensor", CuDuration(1_500_000_000), &clock);
1087        assert!(velocity.is_ok());
1088
1089        let vel = velocity.unwrap();
1090        let epsilon = 0.1;
1091        assert_approx_eq(vel.linear[0], 1.0, epsilon, "linear_velocity_0");
1092        assert_approx_eq(vel.linear[1], 2.0, epsilon, "linear_velocity_1");
1093        assert_approx_eq(vel.linear[2], 0.0, epsilon, "linear_velocity_2");
1094    }
1095
1096    #[test]
1097    fn test_velocity_with_rotation() {
1098        let mut tree = TransformTree::<f32>::new();
1099
1100        let ts1 = CuDuration(1_000_000_000);
1101        let ts2 = CuDuration(2_000_000_000);
1102
1103        let w2b_1 = make_stamped(
1104            "world",
1105            "base",
1106            ts1,
1107            Transform3D::from_matrix([
1108                [1.0, 0.0, 0.0, 0.0],
1109                [0.0, 1.0, 0.0, 0.0],
1110                [0.0, 0.0, 1.0, 0.0],
1111                [0.0, 0.0, 0.0, 1.0],
1112            ]),
1113        );
1114
1115        let b2s_1 = make_stamped(
1116            "base",
1117            "sensor",
1118            ts1,
1119            Transform3D::from_matrix([
1120                [0.0, 1.0, 0.0, 0.0],
1121                [-1.0, 0.0, 0.0, 0.0],
1122                [0.0, 0.0, 1.0, 0.0],
1123                [1.0, 0.0, 0.0, 1.0],
1124            ]),
1125        );
1126
1127        let w2b_2 = make_stamped(
1128            "world",
1129            "base",
1130            ts2,
1131            Transform3D::from_matrix([
1132                [1.0, 0.0, 0.0, 0.0],
1133                [0.0, 1.0, 0.0, 0.0],
1134                [0.0, 0.0, 1.0, 0.0],
1135                [1.0, 0.0, 0.0, 1.0],
1136            ]),
1137        );
1138
1139        let b2s_2 = make_stamped(
1140            "base",
1141            "sensor",
1142            ts2,
1143            Transform3D::from_matrix([
1144                [0.0, 1.0, 0.0, 0.0],
1145                [-1.0, 0.0, 0.0, 0.0],
1146                [0.0, 0.0, 1.0, 0.0],
1147                [1.0, 0.0, 0.0, 1.0],
1148            ]),
1149        );
1150
1151        tree.add_transform(&w2b_1).unwrap();
1152        tree.add_transform(&w2b_2).unwrap();
1153        tree.add_transform(&b2s_1).unwrap();
1154        tree.add_transform(&b2s_2).unwrap();
1155
1156        let clock = RobotClock::default();
1157        let mid_ts = CuDuration(1_500_000_000);
1158
1159        let velocity = tree.lookup_velocity("world", "sensor", mid_ts, &clock);
1160        assert!(velocity.is_ok());
1161        let vel = velocity.unwrap();
1162        let epsilon = 0.2;
1163        assert_approx_eq(vel.linear[0], 1.0, epsilon, "linear_velocity_0");
1164        assert_approx_eq(vel.linear[1], 0.0, epsilon, "linear_velocity_1");
1165        assert_approx_eq(vel.linear[2], 0.0, epsilon, "linear_velocity_2");
1166
1167        let reverse = tree.lookup_velocity("sensor", "world", mid_ts, &clock);
1168        assert!(reverse.is_ok());
1169        let rev_vel = reverse.unwrap();
1170        assert_approx_eq(rev_vel.linear[0], -1.0, epsilon, "linear_velocity_0");
1171        assert_approx_eq(rev_vel.linear[1], 0.0, epsilon, "linear_velocity_1");
1172        assert_approx_eq(rev_vel.linear[2], 0.0, epsilon, "linear_velocity_2");
1173    }
1174
1175    #[test]
1176    fn test_velocity_with_angular_motion() {
1177        let mut tree = TransformTree::<f32>::new();
1178        let ts1 = CuDuration(1_000_000_000);
1179        let ts2 = CuDuration(2_000_000_000);
1180
1181        let w2b_1 = make_stamped(
1182            "world",
1183            "base",
1184            ts1,
1185            Transform3D::from_matrix([
1186                [1.0, 0.0, 0.0, 0.0],
1187                [0.0, 1.0, 0.0, 0.0],
1188                [0.0, 0.0, 1.0, 0.0],
1189                [0.0, 0.0, 0.0, 1.0],
1190            ]),
1191        );
1192
1193        let w2b_2 = make_stamped(
1194            "world",
1195            "base",
1196            ts2,
1197            Transform3D::from_matrix([
1198                [0.0, 1.0, 0.0, 0.0],
1199                [-1.0, 0.0, 0.0, 0.0],
1200                [0.0, 0.0, 1.0, 0.0],
1201                [0.0, 0.0, 0.0, 1.0],
1202            ]),
1203        );
1204
1205        let b2s_1 = make_stamped(
1206            "base",
1207            "sensor",
1208            ts1,
1209            Transform3D::from_matrix([
1210                [1.0, 0.0, 0.0, 0.0],
1211                [0.0, 1.0, 0.0, 0.0],
1212                [0.0, 0.0, 1.0, 0.0],
1213                [1.0, 0.0, 0.0, 1.0],
1214            ]),
1215        );
1216
1217        let b2s_2 = make_stamped(
1218            "base",
1219            "sensor",
1220            ts2,
1221            Transform3D::from_matrix([
1222                [1.0, 0.0, 0.0, 0.0],
1223                [0.0, 1.0, 0.0, 0.0],
1224                [0.0, 0.0, 1.0, 0.0],
1225                [1.0, 0.0, 0.0, 1.0],
1226            ]),
1227        );
1228
1229        tree.add_transform(&w2b_1).unwrap();
1230        tree.add_transform(&w2b_2).unwrap();
1231        tree.add_transform(&b2s_1).unwrap();
1232        tree.add_transform(&b2s_2).unwrap();
1233
1234        let clock = RobotClock::default();
1235        let vel = tree
1236            .lookup_velocity("world", "sensor", CuDuration(1_500_000_000), &clock)
1237            .unwrap();
1238
1239        let epsilon = 0.1;
1240        assert_approx_eq(vel.angular[0], 0.0, epsilon, "angular_velocity_0");
1241        assert_approx_eq(vel.angular[1], 0.0, epsilon, "angular_velocity_1");
1242        assert_approx_eq(vel.angular[2], -1.0, epsilon, "angular_velocity_2");
1243
1244        assert!(!vel.linear[0].is_nan());
1245        assert!(!vel.linear[1].is_nan());
1246        assert!(!vel.linear[2].is_nan());
1247
1248        assert!(!vel.angular[0].is_nan());
1249        assert!(!vel.angular[1].is_nan());
1250        assert!(!vel.angular[2].is_nan());
1251    }
1252
1253    #[test]
1254    fn test_velocity_cache() {
1255        let mut tree = TransformTree::<f32>::new();
1256        let ts1 = CuDuration(1_000_000_000);
1257        let ts2 = CuDuration(2_000_000_000);
1258
1259        let tf1 = make_stamped(
1260            "world",
1261            "robot",
1262            ts1,
1263            Transform3D::from_matrix([
1264                [1.0, 0.0, 0.0, 0.0],
1265                [0.0, 1.0, 0.0, 0.0],
1266                [0.0, 0.0, 1.0, 0.0],
1267                [0.0, 0.0, 0.0, 1.0],
1268            ]),
1269        );
1270
1271        let tf2 = make_stamped(
1272            "world",
1273            "robot",
1274            ts2,
1275            Transform3D::from_matrix([
1276                [1.0, 0.0, 0.0, 0.0],
1277                [0.0, 1.0, 0.0, 0.0],
1278                [0.0, 0.0, 1.0, 0.0],
1279                [2.0, 0.0, 0.0, 1.0],
1280            ]),
1281        );
1282
1283        tree.add_transform(&tf1).unwrap();
1284        tree.add_transform(&tf2).unwrap();
1285
1286        let clock = RobotClock::default();
1287
1288        let start_time = std::time::Instant::now();
1289        let velocity1 = tree.lookup_velocity("world", "robot", CuDuration(1_500_000_000), &clock);
1290        let first_lookup_time = start_time.elapsed();
1291
1292        assert!(velocity1.is_ok());
1293        let vel1 = velocity1.unwrap();
1294        assert_approx_eq(vel1.linear[0], 2.0, 0.01, "linear_velocity_0");
1295
1296        let start_time = std::time::Instant::now();
1297        let velocity2 = tree.lookup_velocity("world", "robot", CuDuration(1_500_000_000), &clock);
1298        let second_lookup_time = start_time.elapsed();
1299
1300        assert!(velocity2.is_ok());
1301        let vel2 = velocity2.unwrap();
1302        assert_approx_eq(vel2.linear[0], 2.0, 0.01, "linear_velocity_0");
1303
1304        tree.clear_cache();
1305
1306        let start_time = std::time::Instant::now();
1307        let velocity3 = tree.lookup_velocity("world", "robot", CuDuration(1_500_000_000), &clock);
1308        let third_lookup_time = start_time.elapsed();
1309
1310        assert!(velocity3.is_ok());
1311
1312        println!("First lookup: {first_lookup_time:?}");
1313        println!("Second lookup (cached): {second_lookup_time:?}");
1314        println!("Third lookup (after cache clear): {third_lookup_time:?}");
1315    }
1316
1317    #[test]
1318    fn test_velocity_cache_invalidation() {
1319        let mut tree = TransformTree::<f32>::new();
1320        let ts1 = CuDuration(1_000_000_000);
1321        let ts2 = CuDuration(2_000_000_000);
1322        let ts3 = CuDuration(3_000_000_000);
1323
1324        let tf1 = make_stamped(
1325            "world",
1326            "robot",
1327            ts1,
1328            Transform3D::from_matrix([
1329                [1.0, 0.0, 0.0, 0.0],
1330                [0.0, 1.0, 0.0, 0.0],
1331                [0.0, 0.0, 1.0, 0.0],
1332                [0.0, 0.0, 0.0, 1.0],
1333            ]),
1334        );
1335
1336        let tf2 = make_stamped(
1337            "world",
1338            "robot",
1339            ts2,
1340            Transform3D::from_matrix([
1341                [1.0, 0.0, 0.0, 0.0],
1342                [0.0, 1.0, 0.0, 0.0],
1343                [0.0, 0.0, 1.0, 0.0],
1344                [1.0, 0.0, 0.0, 1.0],
1345            ]),
1346        );
1347
1348        let tf3 = make_stamped(
1349            "world",
1350            "robot",
1351            ts3,
1352            Transform3D::from_matrix([
1353                [1.0, 0.0, 0.0, 0.0],
1354                [0.0, 1.0, 0.0, 0.0],
1355                [0.0, 0.0, 1.0, 0.0],
1356                [3.0, 0.0, 0.0, 1.0],
1357            ]),
1358        );
1359
1360        tree.add_transform(&tf1).unwrap();
1361        tree.add_transform(&tf2).unwrap();
1362
1363        let clock = RobotClock::default();
1364        let velocity1 = tree
1365            .lookup_velocity("world", "robot", CuDuration(1_500_000_000), &clock)
1366            .unwrap();
1367        assert_approx_eq(velocity1.linear[0], 1.0, 0.01, "linear_velocity_0");
1368
1369        tree.add_transform(&tf3).unwrap();
1370
1371        let velocity2 = tree
1372            .lookup_velocity("world", "robot", CuDuration(1_500_000_000), &clock)
1373            .unwrap();
1374        assert_approx_eq(velocity2.linear[0], 1.0, 0.01, "linear_velocity_0");
1375
1376        let velocity3 = tree
1377            .lookup_velocity("world", "robot", CuDuration(2_500_000_000), &clock)
1378            .unwrap();
1379        assert_approx_eq(velocity3.linear[0], 2.0, 0.01, "linear_velocity_0");
1380    }
1381}