// cu_transform/tree.rs

1use crate::FrameIdString;
2use crate::error::{TransformError, TransformResult};
3use crate::transform::{StampedTransform, TransformBuffer, TransformStore};
4use crate::transform_payload::StampedFrameTransform;
5use crate::velocity::VelocityTransform;
6use crate::velocity_cache::VelocityTransformCache;
7use cu_spatial_payloads::Transform3D;
8use cu29::clock::{CuTime, RobotClock, Tov};
9use cu29::prelude::CuMsgPayload;
10use dashmap::DashMap;
11use petgraph::algo::dijkstra;
12use petgraph::graph::{DiGraph, NodeIndex};
13use serde::Serialize;
14use serde::de::DeserializeOwned;
15use std::collections::HashMap;
16use std::fmt::Debug;
17use std::ops::Neg;
18use std::time::Duration;
19
/// Abstraction over transform types that can produce their own inverse.
///
/// `T` names the scalar type of the implementing transform; it does not appear in
/// the method signature itself.
pub trait HasInverse<T>
where
    T: Copy + Debug + 'static,
{
    /// Returns the inverse of `self`.
    fn inverse(&self) -> Self;
}
24
/// [`HasInverse`] for single-precision transforms; intended to forward to
/// `Transform3D`'s own `inverse`.
///
/// NOTE(review): this relies on `Transform3D<f32>` having an inherent
/// `inverse(&self)` method. Inherent methods are preferred over trait methods when
/// the receiver types match, so `self.inverse()` below is expected to reach that
/// method rather than this impl. If the inherent method ever took `self` by value,
/// method resolution would likely select this trait method first and recurse forever.
impl HasInverse<f32> for Transform3D<f32> {
    fn inverse(&self) -> Self {
        self.inverse()
    }
}

/// [`HasInverse`] for double-precision transforms; intended to forward to
/// `Transform3D`'s own `inverse`.
///
/// NOTE(review): same inherent-method assumption as the `f32` impl above.
impl HasInverse<f64> for Transform3D<f64> {
    fn inverse(&self) -> Self {
        self.inverse()
    }
}
36
37/// The cache entry for a transform query
38#[derive(Clone)]
39struct TransformCacheEntry<T: Copy + Debug + 'static> {
40    /// The cached transform result
41    transform: Transform3D<T>,
42    /// The time for which this transform was calculated
43    time: CuTime,
44    /// When this cache entry was last accessed
45    last_access: CuTime,
46    /// Path used to calculate this transform
47    path_hash: u64,
48}
49
50/// A cache for transforms to avoid recalculating frequently accessed paths
51struct TransformCache<T: Copy + Debug + 'static> {
52    /// Map from (source, target) frames to cached transforms
53    entries: DashMap<(FrameIdString, FrameIdString), TransformCacheEntry<T>>,
54    /// Maximum size of the cache
55    max_size: usize,
56    /// Maximum age of cache entries before invalidation (in nanoseconds)
57    max_age_nanos: u64,
58    /// Last time the cache was cleaned up (needs to be mutable)
59    last_cleanup_cell: std::sync::Mutex<CuTime>,
60    /// Cleanup interval (in nanoseconds)
61    cleanup_interval_nanos: u64,
62}
63
64impl<T: Copy + Debug + 'static> TransformCache<T> {
65    fn new(max_size: usize, max_age: Duration) -> Self {
66        Self {
67            entries: DashMap::with_capacity(max_size),
68            max_size,
69            max_age_nanos: max_age.as_nanos() as u64,
70            last_cleanup_cell: std::sync::Mutex::new(CuTime::from(0u64)),
71            cleanup_interval_nanos: 5_000_000_000, // Clean every 5 seconds
72        }
73    }
74
75    /// Get a cached transform if it exists and is still valid
76    fn get(
77        &self,
78        from: &str,
79        to: &str,
80        time: CuTime,
81        path_hash: u64,
82        robot_clock: &RobotClock,
83    ) -> Option<Transform3D<T>> {
84        let key = (
85            FrameIdString::from(from).expect("Frame name too long"),
86            FrameIdString::from(to).expect("Frame name too long"),
87        );
88
89        if let Some(mut entry) = self.entries.get_mut(&key) {
90            let now = robot_clock.now();
91
92            // Check if the cache entry is for the same time and path
93            if entry.time == time && entry.path_hash == path_hash {
94                // Check if the entry is still valid (not too old)
95                let age = now.as_nanos().saturating_sub(entry.last_access.as_nanos());
96                if age <= self.max_age_nanos {
97                    // Update last access time
98                    entry.last_access = now;
99                    return Some(entry.transform);
100                }
101            }
102        }
103
104        None
105    }
106
107    /// Add a new transform to the cache
108    fn insert(
109        &self,
110        from: &str,
111        to: &str,
112        transform: Transform3D<T>,
113        time: CuTime,
114        path_hash: u64,
115        robot_clock: &RobotClock,
116    ) {
117        let now = robot_clock.now();
118        let key = (
119            FrameIdString::from(from).expect("Frame name too long"),
120            FrameIdString::from(to).expect("Frame name too long"),
121        );
122
123        // If the cache is at capacity, remove the oldest entry
124        if self.entries.len() >= self.max_size {
125            let oldest_key = self
126                .entries
127                .iter()
128                .min_by_key(|entry| entry.last_access)
129                .map(|entry| *entry.key());
130            if let Some(key_to_remove) = oldest_key {
131                self.entries.remove(&key_to_remove);
132            }
133        }
134
135        // Insert the new entry
136        self.entries.insert(
137            key,
138            TransformCacheEntry {
139                transform,
140                time,
141                last_access: now,
142                path_hash,
143            },
144        );
145    }
146
147    /// Check if it's time to clean up the cache
148    fn should_cleanup(&self, robot_clock: &RobotClock) -> bool {
149        let now = robot_clock.now();
150        let last_cleanup = *self.last_cleanup_cell.lock().unwrap();
151        let elapsed = now.as_nanos().saturating_sub(last_cleanup.as_nanos());
152        elapsed >= self.cleanup_interval_nanos
153    }
154
155    /// Clear old entries from the cache
156    fn cleanup(&self, robot_clock: &RobotClock) {
157        let now = robot_clock.now();
158        let mut keys_to_remove = Vec::new();
159
160        // Identify keys to remove
161        for entry in self.entries.iter() {
162            let age = now.as_nanos().saturating_sub(entry.last_access.as_nanos());
163            if age > self.max_age_nanos {
164                keys_to_remove.push(*entry.key());
165            }
166        }
167
168        // Remove expired entries
169        for key in keys_to_remove {
170            self.entries.remove(&key);
171        }
172
173        // Update last cleanup time
174        *self.last_cleanup_cell.lock().unwrap() = now;
175    }
176
177    /// Clear all entries
178    fn clear(&self) {
179        self.entries.clear();
180    }
181}
182
183pub struct TransformTree<T: Copy + Debug + Default + 'static> {
184    graph: DiGraph<FrameIdString, ()>,
185    frame_indices: HashMap<FrameIdString, NodeIndex>,
186    transform_store: TransformStore<T>,
187    // Concurrent cache for transform lookups
188    cache: TransformCache<T>,
189    // Concurrent cache for velocity transform lookups
190    velocity_cache: VelocityTransformCache<T>,
191}
192
/// Types that have a value representing "one" (the multiplicative identity for the
/// numeric types implemented below).
pub trait One {
    /// Returns this type's value for "one".
    fn one() -> Self;
}

// Generates `impl One for <type>` returning the given literal, for each listed pair.
macro_rules! impl_one {
    ($($ty:ty => $value:expr),+ $(,)?) => {
        $(
            impl One for $ty {
                fn one() -> Self {
                    $value
                }
            }
        )+
    };
}

// Common numeric scalar types.
impl_one! {
    f32 => 1.0,
    f64 => 1.0,
    i32 => 1,
    i64 => 1,
    u32 => 1,
    u64 => 1,
}
235
236// We need to limit T to types where Transform3D<T> has Clone and inverse method
237// and now we also require T to implement One
238impl<T: Copy + Debug + Default + One + Serialize + DeserializeOwned + 'static + Neg<Output = T>>
239    TransformTree<T>
240where
241    Transform3D<T>: Clone + HasInverse<T> + std::ops::Mul<Output = Transform3D<T>>,
242    T: std::ops::Add<Output = T>
243        + std::ops::Sub<Output = T>
244        + std::ops::Mul<Output = T>
245        + std::ops::Div<Output = T>
246        + std::ops::AddAssign
247        + std::ops::SubAssign
248        + num_traits::NumCast,
249{
250    /// Default cache size (number of transforms to cache)
251    const DEFAULT_CACHE_SIZE: usize = 100;
252
253    /// Default cache entry lifetime (5 seconds)
254    const DEFAULT_CACHE_AGE: Duration = Duration::from_secs(5);
255
256    /// Create a new transform tree with default settings
257    pub fn new() -> Self {
258        Self {
259            graph: DiGraph::new(),
260            frame_indices: HashMap::new(),
261            transform_store: TransformStore::new(),
262            cache: TransformCache::new(Self::DEFAULT_CACHE_SIZE, Self::DEFAULT_CACHE_AGE),
263            velocity_cache: VelocityTransformCache::new(
264                Self::DEFAULT_CACHE_SIZE,
265                Self::DEFAULT_CACHE_AGE.as_nanos() as u64,
266            ),
267        }
268    }
269
270    /// Create a new transform tree with custom cache settings
271    pub fn with_cache_settings(cache_size: usize, cache_age: Duration) -> Self {
272        Self {
273            graph: DiGraph::new(),
274            frame_indices: HashMap::new(),
275            transform_store: TransformStore::new(),
276            cache: TransformCache::new(cache_size, cache_age),
277            velocity_cache: VelocityTransformCache::new(cache_size, cache_age.as_nanos() as u64),
278        }
279    }
280
281    /// Clear the transform cache
282    pub fn clear_cache(&self) {
283        self.cache.clear();
284        self.velocity_cache.clear();
285    }
286
287    /// Perform scheduled cache cleanup operation
288    pub fn cleanup_cache(&self, robot_clock: &RobotClock) {
289        self.cache.cleanup(robot_clock);
290        self.velocity_cache.cleanup(robot_clock);
291    }
292
293    /// Creates an identity transform matrix in a type-safe way
294    fn create_identity_transform() -> Transform3D<T> {
295        let mut mat = [[T::default(); 4]; 4];
296
297        // Set the diagonal elements to one
298        let one = T::one();
299        mat[0][0] = one;
300        mat[1][1] = one;
301        mat[2][2] = one;
302        mat[3][3] = one;
303
304        Transform3D::from_matrix(mat)
305    }
306
307    fn ensure_frame(&mut self, frame_id: &str) -> NodeIndex {
308        let frame_id_string = FrameIdString::from(frame_id).expect("Frame name too long");
309        *self
310            .frame_indices
311            .entry(frame_id_string)
312            .or_insert_with(|| self.graph.add_node(frame_id_string))
313    }
314
315    fn get_segment_buffer(
316        &self,
317        parent: &FrameIdString,
318        child: &FrameIdString,
319    ) -> TransformResult<TransformBuffer<T>> {
320        self.transform_store
321            .get_buffer(parent, child)
322            .ok_or_else(|| TransformError::TransformNotFound {
323                from: parent.to_string(),
324                to: child.to_string(),
325            })
326    }
327
328    /// add a transform to the tree.
329    pub fn add_transform(&mut self, sft: &StampedFrameTransform<T>) -> TransformResult<()>
330    where
331        T: CuMsgPayload,
332    {
333        let transform_msg = sft.payload().ok_or_else(|| {
334            TransformError::Unknown("Failed to get transform payload".to_string())
335        })?;
336
337        let timestamp = match sft.tov {
338            Tov::Time(time) => time,
339            Tov::Range(range) => range.start, // Use start of range
340            _ => {
341                return Err(TransformError::Unknown(
342                    "Invalid Time of Validity".to_string(),
343                ));
344            }
345        };
346
347        // Ensure frames exist in the graph
348        let parent_idx = self.ensure_frame(&transform_msg.parent_frame);
349        let child_idx = self.ensure_frame(&transform_msg.child_frame);
350
351        // Check for cycles
352        if self.would_create_cycle(parent_idx, child_idx) {
353            return Err(TransformError::CyclicTransformTree);
354        }
355
356        // Add edge if it doesn't exist
357        if !self.graph.contains_edge(parent_idx, child_idx) {
358            self.graph.add_edge(parent_idx, child_idx, ());
359        }
360
361        // Clear velocity cache since we're adding a transform that could change velocities
362        self.velocity_cache.clear();
363
364        // Create StampedTransform for the store (internal implementation detail)
365        let stamped = StampedTransform {
366            transform: transform_msg.transform,
367            stamp: timestamp,
368            parent_frame: transform_msg.parent_frame,
369            child_frame: transform_msg.child_frame,
370        };
371
372        // Add transform to the store
373        self.transform_store.add_transform(stamped);
374        Ok(())
375    }
376
377    fn would_create_cycle(&self, parent: NodeIndex, child: NodeIndex) -> bool {
378        if self.graph.contains_edge(parent, child) {
379            return false;
380        }
381
382        matches!(dijkstra(&self.graph, child, Some(parent), |_| 1), result if result.contains_key(&parent))
383    }
384
385    pub fn find_path(
386        &self,
387        from_frame: &str,
388        to_frame: &str,
389    ) -> TransformResult<Vec<(FrameIdString, FrameIdString, bool)>> {
390        // If frames are the same, return empty path (identity transform)
391        if from_frame == to_frame {
392            return Ok(Vec::new());
393        }
394
395        let from_frame_id = FrameIdString::from(from_frame).expect("Frame name too long");
396        let from_idx = self
397            .frame_indices
398            .get(&from_frame_id)
399            .ok_or(TransformError::FrameNotFound(from_frame.to_string()))?;
400
401        let to_frame_id = FrameIdString::from(to_frame).expect("Frame name too long");
402        let to_idx = self
403            .frame_indices
404            .get(&to_frame_id)
405            .ok_or(TransformError::FrameNotFound(to_frame.to_string()))?;
406
407        // Create an undirected version of the graph to find any path (forward or inverse)
408        let mut undirected_graph = self.graph.clone();
409
410        // Add reverse edges for every existing edge to make it undirected
411        let edges: Vec<_> = self.graph.edge_indices().collect();
412        for edge_idx in edges {
413            let (a, b) = self.graph.edge_endpoints(edge_idx).unwrap();
414            if !undirected_graph.contains_edge(b, a) {
415                undirected_graph.add_edge(b, a, ());
416            }
417        }
418
419        // Now find path in undirected graph
420        let path = dijkstra(&undirected_graph, *from_idx, Some(*to_idx), |_| 1);
421
422        if !path.contains_key(to_idx) {
423            return Err(TransformError::TransformNotFound {
424                from: from_frame.to_string(),
425                to: to_frame.to_string(),
426            });
427        }
428
429        // Reconstruct the path
430        let mut current = *to_idx;
431        let mut path_nodes = vec![current];
432
433        while current != *from_idx {
434            let mut found_next = false;
435
436            // Try all neighbors in undirected graph
437            for neighbor in undirected_graph.neighbors(current) {
438                if path.contains_key(&neighbor) && path[&neighbor] < path[&current] {
439                    current = neighbor;
440                    path_nodes.push(current);
441                    found_next = true;
442                    break;
443                }
444            }
445
446            if !found_next {
447                return Err(TransformError::TransformNotFound {
448                    from: from_frame.to_string(),
449                    to: to_frame.to_string(),
450                });
451            }
452        }
453
454        path_nodes.reverse();
455
456        // Convert node path to edge path with direction information
457        let mut path_edges = Vec::new();
458        for i in 0..path_nodes.len() - 1 {
459            let parent_idx = path_nodes[i];
460            let child_idx = path_nodes[i + 1];
461
462            let parent_frame = self.graph[parent_idx];
463            let child_frame = self.graph[child_idx];
464
465            // Check if this is a forward edge in original directed graph
466            let is_forward = self.graph.contains_edge(parent_idx, child_idx);
467
468            if is_forward {
469                // Forward edge: parent -> child
470                path_edges.push((parent_frame, child_frame, false));
471            } else {
472                // Inverse edge: child <- parent (we need child -> parent)
473                path_edges.push((child_frame, parent_frame, true));
474            }
475        }
476
477        Ok(path_edges)
478    }
479
480    /// Compute a simple hash value for a path to use as a cache key
481    fn compute_path_hash(path: &[(FrameIdString, FrameIdString, bool)]) -> u64 {
482        use std::collections::hash_map::DefaultHasher;
483        use std::hash::{Hash, Hasher};
484
485        let mut hasher = DefaultHasher::new();
486
487        for (parent, child, inverse) in path {
488            parent.hash(&mut hasher);
489            child.hash(&mut hasher);
490            inverse.hash(&mut hasher);
491        }
492
493        hasher.finish()
494    }
495
496    pub fn lookup_transform(
497        &self,
498        from_frame: &str,
499        to_frame: &str,
500        time: CuTime,
501        robot_clock: &RobotClock,
502    ) -> TransformResult<Transform3D<T>> {
503        // Identity case: same frame
504        if from_frame == to_frame {
505            return Ok(Self::create_identity_transform());
506        }
507
508        // Find the path between frames
509        let path = self.find_path(from_frame, to_frame)?;
510
511        if path.is_empty() {
512            // Empty path is another case for identity transform
513            return Ok(Self::create_identity_transform());
514        }
515
516        // Calculate a hash for the path (for cache lookups)
517        let path_hash = Self::compute_path_hash(&path);
518
519        // Try to get the transform from cache - concurrent map allows lock-free reads
520        if let Some(cached_transform) =
521            self.cache
522                .get(from_frame, to_frame, time, path_hash, robot_clock)
523        {
524            return Ok(cached_transform);
525        }
526
527        // Check if it's time to clean up the cache
528        if self.cache.should_cleanup(robot_clock) {
529            self.cache.cleanup(robot_clock);
530        }
531
532        // Cache miss - compute the transform
533
534        // Compose multiple transforms along the path
535        let mut result = Self::create_identity_transform();
536
537        // Iterate through each segment of the path
538        for (parent, child, inverse) in &path {
539            let buffer = self.get_segment_buffer(parent, child)?;
540
541            let transform = buffer
542                .get_closest_transform(time)
543                .ok_or(TransformError::TransformTimeNotAvailable(time))?;
544
545            // Note: In transform composition, the right-most transform is applied first.
546            let transform_to_apply = if *inverse {
547                transform.transform.inverse()
548            } else {
549                transform.transform
550            };
551            result = transform_to_apply * result;
552        }
553
554        // Cache the computed result
555        self.cache
556            .insert(from_frame, to_frame, result, time, path_hash, robot_clock);
557
558        Ok(result)
559    }
560
561    /// Look up the velocity of a frame at a specific time
562    ///
563    /// This computes the velocity by differentiating transforms over time.
564    /// Returns the velocity expressed in the target frame.
565    ///
566    /// Results are automatically cached for improved performance. The cache is
567    /// invalidated when new transforms are added or when cache entries expire based
568    /// on their age. The cache significantly improves performance for repeated lookups
569    /// of the same frames and times.
570    ///
571    /// # Arguments
572    /// * `from_frame` - The source frame
573    /// * `to_frame` - The target frame
574    /// * `time` - The time at which to compute the velocity
575    ///
576    /// # Returns
577    /// * A VelocityTransform containing linear and angular velocity components
578    /// * Error if the transform is not available or cannot be computed
579    ///
580    /// # Performance
581    /// The first lookup of a specific frame pair and time will compute the velocity and
582    /// cache the result. Subsequent lookups will use the cached result, which is much faster.
583    /// For real-time or performance-critical applications, this caching is crucial.
584    ///
585    /// # Cache Management
586    /// The cache is automatically cleared when new transforms are added. You can also
587    /// manually clear the cache with `clear_cache()` or trigger cleanup with `cleanup_cache()`.
588    pub fn lookup_velocity(
589        &self,
590        from_frame: &str,
591        to_frame: &str,
592        time: CuTime,
593        robot_clock: &RobotClock,
594    ) -> TransformResult<VelocityTransform<T>> {
595        // Identity case: same frame
596        if from_frame == to_frame {
597            return Ok(VelocityTransform::default());
598        }
599
600        // Find the path between frames
601        let path = self.find_path(from_frame, to_frame)?;
602
603        if path.is_empty() {
604            // Empty path means identity transform (zero velocity)
605            return Ok(VelocityTransform::default());
606        }
607
608        // Calculate a hash for the path (for cache lookups)
609        let path_hash = Self::compute_path_hash(&path);
610
611        // Try to get the velocity from cache
612        if let Some(cached_velocity) =
613            self.velocity_cache
614                .get(from_frame, to_frame, time, path_hash, robot_clock)
615        {
616            return Ok(cached_velocity);
617        }
618
619        // Check if it's time to clean up the cache
620        if self.velocity_cache.should_cleanup(robot_clock) {
621            self.velocity_cache.cleanup(robot_clock);
622        }
623
624        // Cache miss - compute the velocity
625
626        // Initialize zero velocity
627        let mut result = VelocityTransform::default();
628
629        // Iterate through each segment of the path
630        for (parent, child, inverse) in &path {
631            let buffer = self.get_segment_buffer(parent, child)?;
632
633            // Compute velocity for this segment
634            let segment_velocity = buffer
635                .compute_velocity_at_time(time)
636                .ok_or(TransformError::TransformTimeNotAvailable(time))?;
637
638            // Get the transform at the requested time
639            let transform = buffer
640                .get_closest_transform(time)
641                .ok_or(TransformError::TransformTimeNotAvailable(time))?;
642
643            // Apply the proper velocity transformation
644            // We need the current position for proper velocity transformation
645            let position = [T::default(); 3]; // Assume transformation at origin for simplicity
646            // A more accurate implementation would track the position
647
648            let transformed_velocity = if *inverse {
649                let inverse_transform = transform.transform.inverse();
650                crate::velocity::transform_velocity(
651                    &segment_velocity.negate(),
652                    &inverse_transform,
653                    &position,
654                )
655            } else {
656                crate::velocity::transform_velocity(
657                    &segment_velocity,
658                    &transform.transform,
659                    &position,
660                )
661            };
662
663            result.linear[0] += transformed_velocity.linear[0];
664            result.linear[1] += transformed_velocity.linear[1];
665            result.linear[2] += transformed_velocity.linear[2];
666
667            result.angular[0] += transformed_velocity.angular[0];
668            result.angular[1] += transformed_velocity.angular[1];
669            result.angular[2] += transformed_velocity.angular[2];
670        }
671
672        // Cache the computed result
673        self.velocity_cache.insert(
674            from_frame,
675            to_frame,
676            result.clone(),
677            time,
678            path_hash,
679            robot_clock,
680        );
681
682        Ok(result)
683    }
684}
685
686impl<T: Copy + Debug + Default + One + Serialize + DeserializeOwned + 'static + Neg<Output = T>>
687    Default for TransformTree<T>
688where
689    Transform3D<T>: Clone + HasInverse<T> + std::ops::Mul<Output = Transform3D<T>>,
690    T: std::ops::Add<Output = T>
691        + std::ops::Sub<Output = T>
692        + std::ops::Mul<Output = T>
693        + std::ops::Div<Output = T>
694        + std::ops::AddAssign
695        + std::ops::SubAssign
696        + num_traits::NumCast,
697{
698    fn default() -> Self {
699        Self::new()
700    }
701}
702
703#[cfg(test)]
704#[allow(deprecated)] // We intentionally test deprecated APIs for backward compatibility
705mod tests {
706    use super::*;
707    use crate::test_utils::get_translation;
708    use crate::{FrameTransform, frame_id};
709    use cu29::clock::{CuDuration, RobotClock};
710
711    // Helper function to replace assert_relative_eq
712    fn assert_approx_eq(actual: f32, expected: f32, epsilon: f32, message: &str) {
713        let diff = (actual - expected).abs();
714        assert!(
715            diff <= epsilon,
716            "{message}: expected {expected}, got {actual}, difference {diff} exceeds epsilon {epsilon}",
717        );
718    }
719
720    fn make_stamped(
721        parent: &str,
722        child: &str,
723        ts: CuDuration,
724        tf: Transform3D<f32>,
725    ) -> StampedFrameTransform<f32> {
726        let inner = FrameTransform {
727            transform: tf,
728            parent_frame: frame_id!(parent),
729            child_frame: frame_id!(child),
730        };
731        let mut stf = StampedFrameTransform::new(Some(inner));
732        stf.tov = ts.into();
733        stf
734    }
735
736    // Only use f32/f64 for tests since our inverse transform is only implemented for these types
737
738    #[test]
739    fn test_add_transform() {
740        let mut tree = TransformTree::<f32>::new();
741
742        let inner = FrameTransform {
743            transform: Transform3D::default(),
744            parent_frame: frame_id!("world"),
745            child_frame: frame_id!("robot"),
746        };
747        let mut stf = StampedFrameTransform::new(Some(inner));
748        stf.tov = CuDuration(1000).into();
749
750        assert!(tree.add_transform(&stf).is_ok());
751    }
752
753    #[test]
754    fn test_cyclic_transforms() {
755        let mut tree = TransformTree::<f32>::new();
756
757        let transform1 = make_stamped("world", "robot", 1000.into(), Transform3D::default());
758        let transform2 = make_stamped("robot", "sensor", 1000.into(), Transform3D::default());
759        let transform3 = make_stamped("sensor", "world", 1000.into(), Transform3D::default());
760
761        assert!(tree.add_transform(&transform1).is_ok());
762        assert!(tree.add_transform(&transform2).is_ok());
763
764        let result = tree.add_transform(&transform3);
765        assert!(result.is_err());
766        if let Err(e) = result {
767            assert!(matches!(e, TransformError::CyclicTransformTree));
768        }
769    }
770
771    #[test]
772    fn test_find_path() {
773        let mut tree = TransformTree::<f32>::new();
774
775        let transform1 = make_stamped("world", "robot", 1000.into(), Transform3D::default());
776        let transform2 = make_stamped("robot", "sensor", 1000.into(), Transform3D::default());
777
778        assert!(tree.add_transform(&transform1).is_ok());
779        assert!(tree.add_transform(&transform2).is_ok());
780
781        let path = tree.find_path("world", "sensor");
782        assert!(path.is_ok());
783
784        let path_vec = path.unwrap();
785        assert_eq!(path_vec.len(), 2);
786        assert_eq!(path_vec[0].0.as_str(), "world");
787        assert_eq!(path_vec[0].1.as_str(), "robot");
788        assert_eq!(path_vec[1].0.as_str(), "robot");
789        assert_eq!(path_vec[1].1.as_str(), "sensor");
790    }
791
792    #[test]
793    fn test_lookup_transform_with_inverse() {
794        let mut tree = TransformTree::<f32>::new();
795
796        let matrix = [
797            [1.0, 0.0, 0.0, 0.0],
798            [0.0, 1.0, 0.0, 0.0],
799            [0.0, 0.0, 1.0, 0.0],
800            [2.0, 3.0, 4.0, 1.0],
801        ];
802        let tf = make_stamped(
803            "world",
804            "robot",
805            CuDuration(1000),
806            Transform3D::from_matrix(matrix),
807        );
808
809        assert!(tree.add_transform(&tf).is_ok());
810
811        let clock = RobotClock::default();
812
813        let forward = tree
814            .lookup_transform("world", "robot", CuDuration(1000), &clock)
815            .unwrap();
816        assert_eq!(get_translation(&forward).0, 2.0);
817        assert_eq!(get_translation(&forward).1, 3.0);
818        assert_eq!(get_translation(&forward).2, 4.0);
819
820        let inverse = tree
821            .lookup_transform("robot", "world", CuDuration(1000), &clock)
822            .unwrap();
823        assert_eq!(get_translation(&inverse).0, -2.0);
824        assert_eq!(get_translation(&inverse).1, -3.0);
825        assert_eq!(get_translation(&inverse).2, -4.0);
826    }
827
828    #[test]
829    fn test_multi_step_transform_composition() {
830        let mut tree = TransformTree::<f32>::new();
831        let ts = CuDuration(1000);
832
833        let world_to_base = make_stamped(
834            "world",
835            "base",
836            ts,
837            Transform3D::from_matrix([
838                [1.0, 0.0, 0.0, 0.0],
839                [0.0, 1.0, 0.0, 0.0],
840                [0.0, 0.0, 1.0, 0.0],
841                [1.0, 0.0, 0.0, 1.0],
842            ]),
843        );
844
845        let base_to_arm = make_stamped(
846            "base",
847            "arm",
848            ts,
849            Transform3D::from_matrix([
850                [0.0, 1.0, 0.0, 0.0],
851                [-1.0, 0.0, 0.0, 0.0],
852                [0.0, 0.0, 1.0, 0.0],
853                [0.0, 0.0, 0.0, 1.0],
854            ]),
855        );
856
857        let arm_to_gripper = make_stamped(
858            "arm",
859            "gripper",
860            ts,
861            Transform3D::from_matrix([
862                [1.0, 0.0, 0.0, 0.0],
863                [0.0, 1.0, 0.0, 0.0],
864                [0.0, 0.0, 1.0, 0.0],
865                [0.0, 2.0, 0.0, 1.0],
866            ]),
867        );
868
869        assert!(tree.add_transform(&world_to_base).is_ok());
870        assert!(tree.add_transform(&base_to_arm).is_ok());
871        assert!(tree.add_transform(&arm_to_gripper).is_ok());
872
873        let clock = RobotClock::default();
874        let transform = tree
875            .lookup_transform("world", "gripper", ts, &clock)
876            .unwrap();
877        let epsilon = 1e-5;
878
879        let mat = transform.to_matrix();
880        assert_approx_eq(mat[0][0], 0.0, epsilon, "mat_0_0");
881        assert_approx_eq(mat[1][0], -1.0, epsilon, "mat_1_0");
882        assert_approx_eq(mat[0][1], 1.0, epsilon, "mat_0_1");
883        assert_approx_eq(mat[1][1], 0.0, epsilon, "mat_1_1");
884
885        assert_approx_eq(get_translation(&transform).0, 0.0, epsilon, "translation_x");
886        assert_approx_eq(get_translation(&transform).1, 3.0, epsilon, "translation_y");
887        assert_approx_eq(get_translation(&transform).2, 0.0, epsilon, "translation_z");
888
889        let cached = tree
890            .lookup_transform("world", "gripper", ts, &clock)
891            .unwrap();
892        for i in 0..4 {
893            for j in 0..4 {
894                assert_approx_eq(
895                    transform.to_matrix()[i][j],
896                    cached.to_matrix()[i][j],
897                    epsilon,
898                    "matrix_element",
899                );
900            }
901        }
902
903        let inverse = tree
904            .lookup_transform("gripper", "world", ts, &clock)
905            .unwrap();
906        let inv_mat = inverse.to_matrix();
907        assert_approx_eq(inv_mat[1][0], 1.0, epsilon, "inv_mat_1_0");
908        assert_approx_eq(inv_mat[0][1], -1.0, epsilon, "inv_mat_0_1");
909        assert_approx_eq(get_translation(&inverse).0, -3.0, epsilon, "translation_x");
910        assert_approx_eq(get_translation(&inverse).1, 0.0, epsilon, "translation_y");
911
912        let product = transform * inverse;
913        for i in 0..4 {
914            for j in 0..4 {
915                let expected = if i == j { 1.0 } else { 0.0 };
916                assert_approx_eq(
917                    product.to_matrix()[i][j],
918                    expected,
919                    epsilon,
920                    "matrix_element",
921                );
922            }
923        }
924    }
925
926    #[test]
927    fn test_cache_invalidation() {
928        let mut tree = TransformTree::<f32>::with_cache_settings(5, Duration::from_millis(50));
929        let ts = CuDuration(1000);
930
931        let tf = make_stamped(
932            "a",
933            "b",
934            ts,
935            Transform3D::from_matrix([
936                [1.0, 0.0, 0.0, 0.0],
937                [0.0, 1.0, 0.0, 0.0],
938                [0.0, 0.0, 1.0, 0.0],
939                [1.0, 2.0, 3.0, 1.0],
940            ]),
941        );
942
943        assert!(tree.add_transform(&tf).is_ok());
944
945        let clock = RobotClock::default();
946        let result1 = tree.lookup_transform("a", "b", ts, &clock);
947        assert!(result1.is_ok());
948
949        let transform1 = result1.unwrap();
950        assert_eq!(get_translation(&transform1).0, 1.0);
951
952        std::thread::sleep(Duration::from_millis(100));
953
954        let result2 = tree.lookup_transform("a", "b", ts, &clock);
955        assert!(result2.is_ok());
956
957        tree.clear_cache();
958
959        let result3 = tree.lookup_transform("a", "b", ts, &clock);
960        assert!(result3.is_ok());
961    }
962
963    #[test]
964    fn test_multi_step_transform_with_inverse() {
965        let mut tree = TransformTree::<f32>::new();
966        let ts = CuDuration(1000);
967
968        let world_to_robot = make_stamped(
969            "world",
970            "robot",
971            ts,
972            Transform3D::from_matrix([
973                [1.0, 0.0, 0.0, 0.0],
974                [0.0, 1.0, 0.0, 0.0],
975                [0.0, 0.0, 1.0, 0.0],
976                [1.0, 2.0, 3.0, 1.0],
977            ]),
978        );
979
980        let robot_to_camera = make_stamped(
981            "robot",
982            "camera",
983            ts,
984            Transform3D::from_matrix([
985                [0.0, 1.0, 0.0, 0.0],
986                [-1.0, 0.0, 0.0, 0.0],
987                [0.0, 0.0, 1.0, 0.0],
988                [0.5, 0.0, 0.2, 1.0],
989            ]),
990        );
991
992        assert!(tree.add_transform(&world_to_robot).is_ok());
993        assert!(tree.add_transform(&robot_to_camera).is_ok());
994
995        let clock = RobotClock::default();
996        let transform = tree
997            .lookup_transform("world", "camera", ts, &clock)
998            .unwrap();
999        let epsilon = 1e-5;
1000
1001        let mat = transform.to_matrix();
1002        assert_approx_eq(mat[0][0], 0.0, epsilon, "mat_0_0");
1003        assert_approx_eq(mat[1][0], -1.0, epsilon, "mat_1_0");
1004        assert_approx_eq(mat[0][1], 1.0, epsilon, "mat_0_1");
1005        assert_approx_eq(mat[1][1], 0.0, epsilon, "mat_1_1");
1006
1007        assert_approx_eq(
1008            get_translation(&transform).0,
1009            -1.5,
1010            epsilon,
1011            "translation_x",
1012        );
1013        assert_approx_eq(get_translation(&transform).1, 1.0, epsilon, "translation_y");
1014        assert_approx_eq(get_translation(&transform).2, 3.2, epsilon, "translation_z");
1015
1016        let inverse = tree.lookup_transform("camera", "world", ts, &clock);
1017        assert!(inverse.is_ok());
1018    }
1019
1020    #[test]
1021    fn test_cache_cleanup() {
1022        let tree = TransformTree::<f32>::with_cache_settings(5, Duration::from_millis(10));
1023
1024        // Explicitly trigger cache cleanup
1025        let clock = RobotClock::default();
1026        tree.cleanup_cache(&clock);
1027    }
1028
1029    #[test]
1030    fn test_lookup_velocity() {
1031        let mut tree = TransformTree::<f32>::new();
1032
1033        let w2b_1 = make_stamped(
1034            "world",
1035            "base",
1036            CuDuration(1_000_000_000),
1037            Transform3D::from_matrix([
1038                [1.0, 0.0, 0.0, 0.0],
1039                [0.0, 1.0, 0.0, 0.0],
1040                [0.0, 0.0, 1.0, 0.0],
1041                [0.0, 0.0, 0.0, 1.0],
1042            ]),
1043        );
1044
1045        let w2b_2 = make_stamped(
1046            "world",
1047            "base",
1048            CuDuration(2_000_000_000),
1049            Transform3D::from_matrix([
1050                [1.0, 0.0, 0.0, 0.0],
1051                [0.0, 1.0, 0.0, 0.0],
1052                [0.0, 0.0, 1.0, 0.0],
1053                [1.0, 0.0, 0.0, 1.0],
1054            ]),
1055        );
1056
1057        let b2s_1 = make_stamped(
1058            "base",
1059            "sensor",
1060            CuDuration(1_000_000_000),
1061            Transform3D::from_matrix([
1062                [1.0, 0.0, 0.0, 0.0],
1063                [0.0, 1.0, 0.0, 0.0],
1064                [0.0, 0.0, 1.0, 0.0],
1065                [0.0, 0.0, 0.0, 1.0],
1066            ]),
1067        );
1068
1069        let b2s_2 = make_stamped(
1070            "base",
1071            "sensor",
1072            CuDuration(2_000_000_000),
1073            Transform3D::from_matrix([
1074                [1.0, 0.0, 0.0, 0.0],
1075                [0.0, 1.0, 0.0, 0.0],
1076                [0.0, 0.0, 1.0, 0.0],
1077                [0.0, 2.0, 0.0, 1.0],
1078            ]),
1079        );
1080
1081        tree.add_transform(&w2b_1).unwrap();
1082        tree.add_transform(&w2b_2).unwrap();
1083        tree.add_transform(&b2s_1).unwrap();
1084        tree.add_transform(&b2s_2).unwrap();
1085
1086        let clock = RobotClock::default();
1087        let velocity = tree.lookup_velocity("world", "sensor", CuDuration(1_500_000_000), &clock);
1088        assert!(velocity.is_ok());
1089
1090        let vel = velocity.unwrap();
1091        let epsilon = 0.1;
1092        assert_approx_eq(vel.linear[0], 1.0, epsilon, "linear_velocity_0");
1093        assert_approx_eq(vel.linear[1], 2.0, epsilon, "linear_velocity_1");
1094        assert_approx_eq(vel.linear[2], 0.0, epsilon, "linear_velocity_2");
1095    }
1096
1097    #[test]
1098    fn test_velocity_with_rotation() {
1099        let mut tree = TransformTree::<f32>::new();
1100
1101        let ts1 = CuDuration(1_000_000_000);
1102        let ts2 = CuDuration(2_000_000_000);
1103
1104        let w2b_1 = make_stamped(
1105            "world",
1106            "base",
1107            ts1,
1108            Transform3D::from_matrix([
1109                [1.0, 0.0, 0.0, 0.0],
1110                [0.0, 1.0, 0.0, 0.0],
1111                [0.0, 0.0, 1.0, 0.0],
1112                [0.0, 0.0, 0.0, 1.0],
1113            ]),
1114        );
1115
1116        let b2s_1 = make_stamped(
1117            "base",
1118            "sensor",
1119            ts1,
1120            Transform3D::from_matrix([
1121                [0.0, 1.0, 0.0, 0.0],
1122                [-1.0, 0.0, 0.0, 0.0],
1123                [0.0, 0.0, 1.0, 0.0],
1124                [1.0, 0.0, 0.0, 1.0],
1125            ]),
1126        );
1127
1128        let w2b_2 = make_stamped(
1129            "world",
1130            "base",
1131            ts2,
1132            Transform3D::from_matrix([
1133                [1.0, 0.0, 0.0, 0.0],
1134                [0.0, 1.0, 0.0, 0.0],
1135                [0.0, 0.0, 1.0, 0.0],
1136                [1.0, 0.0, 0.0, 1.0],
1137            ]),
1138        );
1139
1140        let b2s_2 = make_stamped(
1141            "base",
1142            "sensor",
1143            ts2,
1144            Transform3D::from_matrix([
1145                [0.0, 1.0, 0.0, 0.0],
1146                [-1.0, 0.0, 0.0, 0.0],
1147                [0.0, 0.0, 1.0, 0.0],
1148                [1.0, 0.0, 0.0, 1.0],
1149            ]),
1150        );
1151
1152        tree.add_transform(&w2b_1).unwrap();
1153        tree.add_transform(&w2b_2).unwrap();
1154        tree.add_transform(&b2s_1).unwrap();
1155        tree.add_transform(&b2s_2).unwrap();
1156
1157        let clock = RobotClock::default();
1158        let mid_ts = CuDuration(1_500_000_000);
1159
1160        let velocity = tree.lookup_velocity("world", "sensor", mid_ts, &clock);
1161        assert!(velocity.is_ok());
1162        let vel = velocity.unwrap();
1163        let epsilon = 0.2;
1164        assert_approx_eq(vel.linear[0], 1.0, epsilon, "linear_velocity_0");
1165        assert_approx_eq(vel.linear[1], 0.0, epsilon, "linear_velocity_1");
1166        assert_approx_eq(vel.linear[2], 0.0, epsilon, "linear_velocity_2");
1167
1168        let reverse = tree.lookup_velocity("sensor", "world", mid_ts, &clock);
1169        assert!(reverse.is_ok());
1170        let rev_vel = reverse.unwrap();
1171        assert_approx_eq(rev_vel.linear[0], -1.0, epsilon, "linear_velocity_0");
1172        assert_approx_eq(rev_vel.linear[1], 0.0, epsilon, "linear_velocity_1");
1173        assert_approx_eq(rev_vel.linear[2], 0.0, epsilon, "linear_velocity_2");
1174    }
1175
1176    #[test]
1177    fn test_velocity_with_angular_motion() {
1178        let mut tree = TransformTree::<f32>::new();
1179        let ts1 = CuDuration(1_000_000_000);
1180        let ts2 = CuDuration(2_000_000_000);
1181
1182        let w2b_1 = make_stamped(
1183            "world",
1184            "base",
1185            ts1,
1186            Transform3D::from_matrix([
1187                [1.0, 0.0, 0.0, 0.0],
1188                [0.0, 1.0, 0.0, 0.0],
1189                [0.0, 0.0, 1.0, 0.0],
1190                [0.0, 0.0, 0.0, 1.0],
1191            ]),
1192        );
1193
1194        let w2b_2 = make_stamped(
1195            "world",
1196            "base",
1197            ts2,
1198            Transform3D::from_matrix([
1199                [0.0, 1.0, 0.0, 0.0],
1200                [-1.0, 0.0, 0.0, 0.0],
1201                [0.0, 0.0, 1.0, 0.0],
1202                [0.0, 0.0, 0.0, 1.0],
1203            ]),
1204        );
1205
1206        let b2s_1 = make_stamped(
1207            "base",
1208            "sensor",
1209            ts1,
1210            Transform3D::from_matrix([
1211                [1.0, 0.0, 0.0, 0.0],
1212                [0.0, 1.0, 0.0, 0.0],
1213                [0.0, 0.0, 1.0, 0.0],
1214                [1.0, 0.0, 0.0, 1.0],
1215            ]),
1216        );
1217
1218        let b2s_2 = make_stamped(
1219            "base",
1220            "sensor",
1221            ts2,
1222            Transform3D::from_matrix([
1223                [1.0, 0.0, 0.0, 0.0],
1224                [0.0, 1.0, 0.0, 0.0],
1225                [0.0, 0.0, 1.0, 0.0],
1226                [1.0, 0.0, 0.0, 1.0],
1227            ]),
1228        );
1229
1230        tree.add_transform(&w2b_1).unwrap();
1231        tree.add_transform(&w2b_2).unwrap();
1232        tree.add_transform(&b2s_1).unwrap();
1233        tree.add_transform(&b2s_2).unwrap();
1234
1235        let clock = RobotClock::default();
1236        let vel = tree
1237            .lookup_velocity("world", "sensor", CuDuration(1_500_000_000), &clock)
1238            .unwrap();
1239
1240        let epsilon = 0.1;
1241        assert_approx_eq(vel.angular[0], 0.0, epsilon, "angular_velocity_0");
1242        assert_approx_eq(vel.angular[1], 0.0, epsilon, "angular_velocity_1");
1243        assert_approx_eq(vel.angular[2], -1.0, epsilon, "angular_velocity_2");
1244
1245        assert!(!vel.linear[0].is_nan());
1246        assert!(!vel.linear[1].is_nan());
1247        assert!(!vel.linear[2].is_nan());
1248
1249        assert!(!vel.angular[0].is_nan());
1250        assert!(!vel.angular[1].is_nan());
1251        assert!(!vel.angular[2].is_nan());
1252    }
1253
1254    #[test]
1255    fn test_velocity_cache() {
1256        let mut tree = TransformTree::<f32>::new();
1257        let ts1 = CuDuration(1_000_000_000);
1258        let ts2 = CuDuration(2_000_000_000);
1259
1260        let tf1 = make_stamped(
1261            "world",
1262            "robot",
1263            ts1,
1264            Transform3D::from_matrix([
1265                [1.0, 0.0, 0.0, 0.0],
1266                [0.0, 1.0, 0.0, 0.0],
1267                [0.0, 0.0, 1.0, 0.0],
1268                [0.0, 0.0, 0.0, 1.0],
1269            ]),
1270        );
1271
1272        let tf2 = make_stamped(
1273            "world",
1274            "robot",
1275            ts2,
1276            Transform3D::from_matrix([
1277                [1.0, 0.0, 0.0, 0.0],
1278                [0.0, 1.0, 0.0, 0.0],
1279                [0.0, 0.0, 1.0, 0.0],
1280                [2.0, 0.0, 0.0, 1.0],
1281            ]),
1282        );
1283
1284        tree.add_transform(&tf1).unwrap();
1285        tree.add_transform(&tf2).unwrap();
1286
1287        let clock = RobotClock::default();
1288
1289        let start_time = std::time::Instant::now();
1290        let velocity1 = tree.lookup_velocity("world", "robot", CuDuration(1_500_000_000), &clock);
1291        let first_lookup_time = start_time.elapsed();
1292
1293        assert!(velocity1.is_ok());
1294        let vel1 = velocity1.unwrap();
1295        assert_approx_eq(vel1.linear[0], 2.0, 0.01, "linear_velocity_0");
1296
1297        let start_time = std::time::Instant::now();
1298        let velocity2 = tree.lookup_velocity("world", "robot", CuDuration(1_500_000_000), &clock);
1299        let second_lookup_time = start_time.elapsed();
1300
1301        assert!(velocity2.is_ok());
1302        let vel2 = velocity2.unwrap();
1303        assert_approx_eq(vel2.linear[0], 2.0, 0.01, "linear_velocity_0");
1304
1305        tree.clear_cache();
1306
1307        let start_time = std::time::Instant::now();
1308        let velocity3 = tree.lookup_velocity("world", "robot", CuDuration(1_500_000_000), &clock);
1309        let third_lookup_time = start_time.elapsed();
1310
1311        assert!(velocity3.is_ok());
1312
1313        println!("First lookup: {first_lookup_time:?}");
1314        println!("Second lookup (cached): {second_lookup_time:?}");
1315        println!("Third lookup (after cache clear): {third_lookup_time:?}");
1316    }
1317
1318    #[test]
1319    fn test_velocity_cache_invalidation() {
1320        let mut tree = TransformTree::<f32>::new();
1321        let ts1 = CuDuration(1_000_000_000);
1322        let ts2 = CuDuration(2_000_000_000);
1323        let ts3 = CuDuration(3_000_000_000);
1324
1325        let tf1 = make_stamped(
1326            "world",
1327            "robot",
1328            ts1,
1329            Transform3D::from_matrix([
1330                [1.0, 0.0, 0.0, 0.0],
1331                [0.0, 1.0, 0.0, 0.0],
1332                [0.0, 0.0, 1.0, 0.0],
1333                [0.0, 0.0, 0.0, 1.0],
1334            ]),
1335        );
1336
1337        let tf2 = make_stamped(
1338            "world",
1339            "robot",
1340            ts2,
1341            Transform3D::from_matrix([
1342                [1.0, 0.0, 0.0, 0.0],
1343                [0.0, 1.0, 0.0, 0.0],
1344                [0.0, 0.0, 1.0, 0.0],
1345                [1.0, 0.0, 0.0, 1.0],
1346            ]),
1347        );
1348
1349        let tf3 = make_stamped(
1350            "world",
1351            "robot",
1352            ts3,
1353            Transform3D::from_matrix([
1354                [1.0, 0.0, 0.0, 0.0],
1355                [0.0, 1.0, 0.0, 0.0],
1356                [0.0, 0.0, 1.0, 0.0],
1357                [3.0, 0.0, 0.0, 1.0],
1358            ]),
1359        );
1360
1361        tree.add_transform(&tf1).unwrap();
1362        tree.add_transform(&tf2).unwrap();
1363
1364        let clock = RobotClock::default();
1365        let velocity1 = tree
1366            .lookup_velocity("world", "robot", CuDuration(1_500_000_000), &clock)
1367            .unwrap();
1368        assert_approx_eq(velocity1.linear[0], 1.0, 0.01, "linear_velocity_0");
1369
1370        tree.add_transform(&tf3).unwrap();
1371
1372        let velocity2 = tree
1373            .lookup_velocity("world", "robot", CuDuration(1_500_000_000), &clock)
1374            .unwrap();
1375        assert_approx_eq(velocity2.linear[0], 1.0, 0.01, "linear_velocity_0");
1376
1377        let velocity3 = tree
1378            .lookup_velocity("world", "robot", CuDuration(2_500_000_000), &clock)
1379            .unwrap();
1380        assert_approx_eq(velocity3.linear[0], 2.0, 0.01, "linear_velocity_0");
1381    }
1382}