rag_plusplus_core/distance/
mod.rs

1//! Distance Computation Module
2//!
3//! Provides optimized distance/similarity functions for vector operations.
4
5#![allow(unsafe_code)]  // SIMD dispatch requires unsafe for intrinsics
6//!
7//! # Architecture
8//!
9//! This module is designed for SIMD acceleration:
10//! - `scalar`: Pure Rust baseline implementations
11//! - `simd_avx2`: AVX2-accelerated (x86_64) - conditionally compiled
12//! - `simd_neon`: NEON-accelerated (ARM64) - future
13//!
14//! Runtime dispatch automatically selects the fastest available implementation
15//! based on CPU feature detection.
16//!
17//! # Distance Types
18//!
19//! - **L2 (Euclidean)**: `sqrt(sum((a[i] - b[i])^2))` - lower is more similar
20//! - **Inner Product**: `sum(a[i] * b[i])` - higher is more similar
21//! - **Cosine**: `dot(a, b) / (norm(a) * norm(b))` - higher is more similar
22//!
23//! # Performance Notes
24//!
25//! - AVX2 provides 4-8x speedup on x86_64 CPUs (Intel Haswell+, AMD Zen+)
26//! - For cosine similarity on pre-normalized vectors, use inner product instead
27//!   (equivalent result, avoids redundant normalization)
28
29pub mod scalar;
30
31#[cfg(target_arch = "x86_64")]
32pub mod simd_avx2;
33
34// Re-export scalar functions as the baseline API
35pub use scalar::{
36    l2_distance,
37    l2_distance_squared,
38    inner_product,
39    cosine_similarity,
40    cosine_distance,
41    norm,
42    norm_squared,
43    normalize,
44    normalize_in_place,
45    // Batch operations
46    normalize_batch,
47    normalize_batch_flat,
48    compute_norms_batch,
49    find_unnormalized,
50};
51
52use crate::index::DistanceType;
53
54// =============================================================================
55// RUNTIME SIMD DISPATCH
56// =============================================================================
57
/// Reports whether the running CPU supports both AVX2 and FMA.
///
/// Only compiled on x86_64. The SIMD dispatchers in this module consult it
/// before entering an AVX2 kernel.
#[cfg(target_arch = "x86_64")]
#[inline]
fn has_avx2_fma() -> bool {
    // Without AVX2 there is no point in probing for FMA.
    if !is_x86_feature_detected!("avx2") {
        return false;
    }
    is_x86_feature_detected!("fma")
}
64
65/// L2 distance with automatic SIMD dispatch.
66///
67/// Selects the fastest implementation based on CPU features:
68/// - AVX2+FMA on x86_64 (if available)
69/// - Scalar fallback otherwise
70#[inline]
71pub fn l2_distance_fast(a: &[f32], b: &[f32]) -> f32 {
72    #[cfg(target_arch = "x86_64")]
73    {
74        if has_avx2_fma() {
75            // SAFETY: CPU feature detection guarantees AVX2+FMA support
76            return unsafe { simd_avx2::l2_distance_avx2(a, b) };
77        }
78    }
79    scalar::l2_distance(a, b)
80}
81
82/// Inner product with automatic SIMD dispatch.
83#[inline]
84pub fn inner_product_fast(a: &[f32], b: &[f32]) -> f32 {
85    #[cfg(target_arch = "x86_64")]
86    {
87        if has_avx2_fma() {
88            return unsafe { simd_avx2::inner_product_avx2(a, b) };
89        }
90    }
91    scalar::inner_product(a, b)
92}
93
94/// Cosine similarity with automatic SIMD dispatch.
95#[inline]
96pub fn cosine_similarity_fast(a: &[f32], b: &[f32]) -> f32 {
97    #[cfg(target_arch = "x86_64")]
98    {
99        if has_avx2_fma() {
100            return unsafe { simd_avx2::cosine_similarity_avx2(a, b) };
101        }
102    }
103    scalar::cosine_similarity(a, b)
104}
105
106/// Cosine distance with automatic SIMD dispatch.
107#[inline]
108pub fn cosine_distance_fast(a: &[f32], b: &[f32]) -> f32 {
109    1.0 - cosine_similarity_fast(a, b)
110}
111
112/// Norm with automatic SIMD dispatch.
113#[inline]
114pub fn norm_fast(a: &[f32]) -> f32 {
115    #[cfg(target_arch = "x86_64")]
116    {
117        if has_avx2_fma() {
118            return unsafe { simd_avx2::norm_avx2(a) };
119        }
120    }
121    scalar::norm(a)
122}
123
124/// Normalize in-place with automatic SIMD dispatch.
125#[inline]
126pub fn normalize_in_place_fast(a: &mut [f32]) -> f32 {
127    #[cfg(target_arch = "x86_64")]
128    {
129        if has_avx2_fma() {
130            return unsafe { simd_avx2::normalize_in_place_avx2(a) };
131        }
132    }
133    scalar::normalize_in_place(a)
134}
135
136/// Compute distance between two vectors based on distance type.
137///
138/// This is the primary dispatch function used by indexes.
139/// Automatically uses SIMD acceleration when available.
140///
141/// # Arguments
142///
143/// * `a` - First vector
144/// * `b` - Second vector
145/// * `distance_type` - Type of distance metric
146///
147/// # Returns
148///
149/// Distance value. Interpretation depends on distance type:
150/// - L2: Lower is more similar (0 = identical)
151/// - InnerProduct: Higher is more similar
152/// - Cosine: Higher is more similar (range -1 to 1)
153///
154/// # Panics
155///
156/// Panics if vectors have different lengths.
157#[inline]
158pub fn compute_distance(a: &[f32], b: &[f32], distance_type: DistanceType) -> f32 {
159    debug_assert_eq!(a.len(), b.len(), "Vector dimension mismatch");
160
161    match distance_type {
162        DistanceType::L2 => l2_distance_fast(a, b),
163        DistanceType::InnerProduct => inner_product_fast(a, b),
164        DistanceType::Cosine => cosine_similarity_fast(a, b),
165    }
166}
167
168/// Compute distance for heap-based search (lower = better for all types).
169///
170/// This function returns values suitable for min-heap based nearest neighbor search.
171/// For similarity metrics (IP, Cosine), the sign/value is adjusted so that
172/// lower values indicate more similar vectors.
173/// Automatically uses SIMD acceleration when available.
174///
175/// # Arguments
176///
177/// * `a` - First vector
178/// * `b` - Second vector
179/// * `distance_type` - Type of distance metric
180///
181/// # Returns
182///
183/// Heap-compatible distance where lower = more similar.
184#[inline]
185pub fn compute_distance_for_heap(a: &[f32], b: &[f32], distance_type: DistanceType) -> f32 {
186    debug_assert_eq!(a.len(), b.len(), "Vector dimension mismatch");
187
188    match distance_type {
189        DistanceType::L2 => l2_distance_fast(a, b),
190        DistanceType::InnerProduct => -inner_product_fast(a, b), // Negate: higher IP = lower heap value
191        DistanceType::Cosine => cosine_distance_fast(a, b),       // 1 - cos: higher cos = lower distance
192    }
193}
194
195/// Check if a distance type uses similarity (higher = more similar).
196#[inline]
197#[must_use]
198pub const fn is_similarity_metric(distance_type: DistanceType) -> bool {
199    matches!(distance_type, DistanceType::InnerProduct | DistanceType::Cosine)
200}
201
202// =============================================================================
203// BATCH OPERATIONS WITH SIMD DISPATCH
204// =============================================================================
205
/// Normalizes every vector of a flat batch in place, in parallel.
///
/// `data` is split into consecutive `dim`-sized chunks with rayon's
/// `par_chunks_mut`, and each chunk is normalized through
/// [`normalize_in_place_fast`], so SIMD dispatch applies per vector.
///
/// Only compiled when the `parallel` feature is enabled.
///
/// # Arguments
///
/// * `data` - Flat array of vectors stored contiguously
/// * `dim` - Dimension of each vector
///
/// # Returns
///
/// One value per vector: whatever [`normalize_in_place_fast`] returned for
/// it (the vector's original norm).
///
/// # Panics
///
/// Panics if `dim` is zero or if `data.len()` is not a multiple of `dim`.
#[cfg(feature = "parallel")]
pub fn normalize_batch_parallel(data: &mut [f32], dim: usize) -> Vec<f32> {
    use rayon::prelude::*;

    assert!(dim > 0, "Dimension must be > 0");
    assert!(data.len() % dim == 0, "Data length must be multiple of dimension");

    // Each chunk is a disjoint `&mut [f32]`, so the vectors can be normalized
    // independently of one another.
    let norms: Vec<f32> = data
        .par_chunks_mut(dim)
        .map(normalize_in_place_fast)
        .collect();
    norms
}
237
/// Single-threaded stand-in for `normalize_batch_parallel` when the
/// `parallel` feature is disabled.
///
/// Keeps the same name and signature as the rayon-backed variant and simply
/// delegates to [`normalize_batch_flat_fast`], so the arguments, return value
/// and panics are those of that function.
#[cfg(not(feature = "parallel"))]
pub fn normalize_batch_parallel(data: &mut [f32], dim: usize) -> Vec<f32> {
    normalize_batch_flat_fast(data, dim)
}
243
244/// Batch L2 normalization with SIMD dispatch (single-threaded).
245///
246/// # Arguments
247///
248/// * `data` - Flat array of vectors stored contiguously
249/// * `dim` - Dimension of each vector
250///
251/// # Returns
252///
253/// Vector of original norms for each vector.
254pub fn normalize_batch_flat_fast(data: &mut [f32], dim: usize) -> Vec<f32> {
255    assert!(dim > 0, "Dimension must be > 0");
256    assert!(data.len() % dim == 0, "Data length must be multiple of dimension");
257
258    let n_vectors = data.len() / dim;
259    let mut norms = Vec::with_capacity(n_vectors);
260
261    for i in 0..n_vectors {
262        let start = i * dim;
263        let end = start + dim;
264        let vector = &mut data[start..end];
265        let n = normalize_in_place_fast(vector);
266        norms.push(n);
267    }
268
269    norms
270}
271
272/// Compute norms for a batch of vectors with SIMD dispatch.
273pub fn compute_norms_batch_fast(data: &[f32], dim: usize) -> Vec<f32> {
274    assert!(dim > 0, "Dimension must be > 0");
275    assert!(data.len() % dim == 0, "Data length must be multiple of dimension");
276
277    let n_vectors = data.len() / dim;
278    let mut norms = Vec::with_capacity(n_vectors);
279
280    for i in 0..n_vectors {
281        let start = i * dim;
282        let end = start + dim;
283        let vector = &data[start..end];
284        norms.push(norm_fast(vector));
285    }
286
287    norms
288}
289
290// =============================================================================
291// TRAJECTORY-WEIGHTED DISTANCE (TPO Integration)
292// =============================================================================
293
294use crate::trajectory::{TrajectoryCoordinate, TrajectoryCoordinate5D};
295
296/// Compute trajectory-weighted cosine similarity.
297///
298/// Combines semantic similarity (cosine of embeddings) with spatial similarity
299/// (distance in trajectory coordinate space). This enables retrieval that
300/// considers both content and structural position.
301///
302/// # Formula
303///
304/// ```text
305/// weighted_sim = (1 - coord_weight) * cosine(a, b) + coord_weight * (1 - coord_dist)
306/// ```
307///
308/// # Arguments
309///
310/// * `a` - First embedding vector
311/// * `b` - Second embedding vector
312/// * `coord_a` - Trajectory coordinate of first episode
313/// * `coord_b` - Trajectory coordinate of second episode
314/// * `coord_weight` - Weight for coordinate component [0, 1]
315///   - 0.0 = pure semantic similarity (ignore coordinates)
316///   - 0.5 = equal weight to semantic and spatial
317///   - 1.0 = pure spatial similarity (ignore content)
318///
319/// # Returns
320///
321/// Combined similarity score. Range depends on inputs but typically [0, 1].
322///
323/// # Example
324///
325/// ```
326/// use rag_plusplus_core::distance::trajectory_weighted_cosine;
327/// use rag_plusplus_core::trajectory::TrajectoryCoordinate;
328///
329/// let emb_a = vec![1.0, 0.0, 0.0];
330/// let emb_b = vec![0.9, 0.436, 0.0]; // Similar direction
331///
332/// let coord_a = TrajectoryCoordinate::new(1, 0, 0.9, 0.2);
333/// let coord_b = TrajectoryCoordinate::new(2, 0, 0.8, 0.3); // Close in trajectory
334///
335/// let sim = trajectory_weighted_cosine(&emb_a, &emb_b, &coord_a, &coord_b, 0.3);
336/// assert!(sim > 0.8); // High due to both semantic and spatial similarity
337/// ```
338#[inline]
339pub fn trajectory_weighted_cosine(
340    a: &[f32],
341    b: &[f32],
342    coord_a: &TrajectoryCoordinate,
343    coord_b: &TrajectoryCoordinate,
344    coord_weight: f32,
345) -> f32 {
346    let coord_weight = coord_weight.clamp(0.0, 1.0);
347
348    let semantic_sim = cosine_similarity_fast(a, b);
349    let coord_dist = coord_a.distance(coord_b);
350
351    // Normalize coord_dist to [0, 1] range (max distance is ~4.0 for 4D coords)
352    // Using a typical max distance of 4.0 for normalization
353    let coord_sim = (1.0 - coord_dist / 4.0).clamp(0.0, 1.0);
354
355    (1.0 - coord_weight) * semantic_sim + coord_weight * coord_sim
356}
357
358/// Compute trajectory-weighted cosine similarity with 5D coordinates.
359///
360/// Same as [`trajectory_weighted_cosine`] but uses 5D coordinates that include
361/// the complexity dimension from TPO.
362#[inline]
363pub fn trajectory_weighted_cosine_5d(
364    a: &[f32],
365    b: &[f32],
366    coord_a: &TrajectoryCoordinate5D,
367    coord_b: &TrajectoryCoordinate5D,
368    coord_weight: f32,
369) -> f32 {
370    let coord_weight = coord_weight.clamp(0.0, 1.0);
371
372    let semantic_sim = cosine_similarity_fast(a, b);
373    let coord_dist = coord_a.distance(coord_b);
374
375    // Normalize coord_dist to [0, 1] range (max distance is ~4.5 for 5D coords)
376    let coord_sim = (1.0 - coord_dist / 4.5).clamp(0.0, 1.0);
377
378    (1.0 - coord_weight) * semantic_sim + coord_weight * coord_sim
379}
380
381/// Compute trajectory-weighted L2 distance.
382///
383/// Combines L2 distance with trajectory coordinate distance.
384/// Lower values indicate more similar episodes.
385#[inline]
386pub fn trajectory_weighted_l2(
387    a: &[f32],
388    b: &[f32],
389    coord_a: &TrajectoryCoordinate,
390    coord_b: &TrajectoryCoordinate,
391    coord_weight: f32,
392) -> f32 {
393    let coord_weight = coord_weight.clamp(0.0, 1.0);
394
395    let semantic_dist = l2_distance_fast(a, b);
396    let coord_dist = coord_a.distance(coord_b);
397
398    (1.0 - coord_weight) * semantic_dist + coord_weight * coord_dist
399}
400
401/// Compute trajectory-weighted inner product.
402///
403/// Combines inner product with trajectory coordinate similarity.
404/// Higher values indicate more similar episodes.
405#[inline]
406pub fn trajectory_weighted_inner_product(
407    a: &[f32],
408    b: &[f32],
409    coord_a: &TrajectoryCoordinate,
410    coord_b: &TrajectoryCoordinate,
411    coord_weight: f32,
412) -> f32 {
413    let coord_weight = coord_weight.clamp(0.0, 1.0);
414
415    let semantic_sim = inner_product_fast(a, b);
416    let coord_dist = coord_a.distance(coord_b);
417
418    // For inner product, we need to convert coord_dist to a similarity-like value
419    let coord_sim = (4.0 - coord_dist).max(0.0); // Higher = more similar
420
421    (1.0 - coord_weight) * semantic_sim + coord_weight * coord_sim
422}
423
424/// Compute trajectory-weighted distance with configurable distance type.
425///
426/// # Arguments
427///
428/// * `a` - First embedding vector
429/// * `b` - Second embedding vector
430/// * `coord_a` - Trajectory coordinate of first episode
431/// * `coord_b` - Trajectory coordinate of second episode
432/// * `distance_type` - Type of semantic distance to use
433/// * `coord_weight` - Weight for coordinate component [0, 1]
434///
435/// # Returns
436///
437/// Combined distance/similarity. Interpretation depends on distance_type.
438#[inline]
439pub fn trajectory_weighted_distance(
440    a: &[f32],
441    b: &[f32],
442    coord_a: &TrajectoryCoordinate,
443    coord_b: &TrajectoryCoordinate,
444    distance_type: DistanceType,
445    coord_weight: f32,
446) -> f32 {
447    match distance_type {
448        DistanceType::L2 => trajectory_weighted_l2(a, b, coord_a, coord_b, coord_weight),
449        DistanceType::InnerProduct => trajectory_weighted_inner_product(a, b, coord_a, coord_b, coord_weight),
450        DistanceType::Cosine => trajectory_weighted_cosine(a, b, coord_a, coord_b, coord_weight),
451    }
452}
453
/// Configuration for trajectory-weighted distance computation.
///
/// Plain, copyable bundle of settings. None of the functions visible in this
/// module take it as a parameter; presumably it is consumed by callers
/// elsewhere in the crate (TODO confirm).
#[derive(Debug, Clone, Copy)]
pub struct TrajectoryDistanceConfig {
    /// Weight for coordinate component [0, 1]
    /// (0.0 = semantic only; higher values give trajectory position more say).
    pub coord_weight: f32,
    /// Base distance type used for the semantic (embedding) component.
    pub distance_type: DistanceType,
    /// Whether to boost similarity for same-phase episodes.
    pub phase_boost: bool,
    /// Boost amount for same-phase episodes (the default of 0.1 is described
    /// as a 10% boost).
    pub phase_boost_amount: f32,
}
466
467impl Default for TrajectoryDistanceConfig {
468    fn default() -> Self {
469        Self {
470            coord_weight: 0.2,       // Mostly semantic, some trajectory context
471            distance_type: DistanceType::Cosine,
472            phase_boost: true,
473            phase_boost_amount: 0.1, // 10% boost for same phase
474        }
475    }
476}
477
478impl TrajectoryDistanceConfig {
479    /// Create config for pure semantic distance (no trajectory weighting).
480    pub fn semantic_only() -> Self {
481        Self {
482            coord_weight: 0.0,
483            ..Default::default()
484        }
485    }
486
487    /// Create config with equal semantic and trajectory weight.
488    pub fn balanced() -> Self {
489        Self {
490            coord_weight: 0.5,
491            ..Default::default()
492        }
493    }
494
495    /// Create config emphasizing trajectory structure.
496    pub fn trajectory_focused() -> Self {
497        Self {
498            coord_weight: 0.7,
499            ..Default::default()
500        }
501    }
502}
503
#[cfg(test)]
mod tests {
    use super::*;

    // =========================================================================
    // METRIC DISPATCH TESTS
    // =========================================================================

    #[test]
    fn test_compute_distance_l2() {
        let a = [1.0, 0.0, 0.0, 0.0];
        let b = [0.0, 1.0, 0.0, 0.0];

        let dist = compute_distance(&a, &b, DistanceType::L2);
        assert!((dist - std::f32::consts::SQRT_2).abs() < 1e-6);
    }

    #[test]
    fn test_compute_distance_inner_product() {
        let a = [1.0, 2.0, 3.0];
        let b = [4.0, 5.0, 6.0];

        let dist = compute_distance(&a, &b, DistanceType::InnerProduct);
        assert!((dist - 32.0).abs() < 1e-6); // 1*4 + 2*5 + 3*6 = 32
    }

    #[test]
    fn test_compute_distance_cosine() {
        let a = [1.0, 0.0];
        let b = [1.0, 0.0];

        let dist = compute_distance(&a, &b, DistanceType::Cosine);
        assert!((dist - 1.0).abs() < 1e-6); // Identical = cosine 1.0
    }

    #[test]
    fn test_heap_distance_ordering() {
        let a = [1.0, 0.0, 0.0];
        let b_close = [0.9, 0.1, 0.0];
        let b_far = [0.0, 1.0, 0.0];

        // `b_close` is nearer to `a` than `b_far` under every metric:
        //   L2:     ~0.141 vs ~1.414
        //   IP:     0.9    vs 0.0    (heap values -0.9 vs -0.0)
        //   Cosine: ~0.994 vs 0.0    (heap values ~0.006 vs 1.0)
        // so the heap distance must be strictly lower for `b_close` each time.
        for dt in [DistanceType::L2, DistanceType::InnerProduct, DistanceType::Cosine] {
            let d_close = compute_distance_for_heap(&a, &b_close, dt);
            let d_far = compute_distance_for_heap(&a, &b_far, dt);

            assert!(
                d_close < d_far,
                "{:?}: expected closer vector to have lower heap distance ({} vs {})",
                dt,
                d_close,
                d_far
            );
        }
    }

    #[test]
    fn test_is_similarity_metric() {
        assert!(!is_similarity_metric(DistanceType::L2));
        assert!(is_similarity_metric(DistanceType::InnerProduct));
        assert!(is_similarity_metric(DistanceType::Cosine));
    }

    // =========================================================================
    // TRAJECTORY-WEIGHTED DISTANCE TESTS
    // =========================================================================

    #[test]
    fn test_trajectory_weighted_cosine_pure_semantic() {
        use crate::trajectory::TrajectoryCoordinate;

        let a = [1.0, 0.0, 0.0];
        let b = [1.0, 0.0, 0.0];

        let coord_a = TrajectoryCoordinate::new(0, 0, 1.0, 0.0);
        let coord_b = TrajectoryCoordinate::new(5, 3, 0.2, 1.0); // Very different coords

        // With coord_weight = 0, should be pure cosine
        let sim = trajectory_weighted_cosine(&a, &b, &coord_a, &coord_b, 0.0);
        assert!((sim - 1.0).abs() < 1e-6); // Identical embeddings
    }

    #[test]
    fn test_trajectory_weighted_cosine_pure_spatial() {
        use crate::trajectory::TrajectoryCoordinate;

        let a = [1.0, 0.0, 0.0];
        let b = [0.0, 1.0, 0.0]; // Orthogonal

        let coord_a = TrajectoryCoordinate::new(1, 0, 0.9, 0.5);
        let coord_b = TrajectoryCoordinate::new(1, 0, 0.9, 0.5); // Identical coords

        // With coord_weight = 1, should be pure spatial (coords identical = max similarity)
        let sim = trajectory_weighted_cosine(&a, &b, &coord_a, &coord_b, 1.0);
        assert!((sim - 1.0).abs() < 1e-6); // Identical coordinates
    }

    #[test]
    fn test_trajectory_weighted_cosine_mixed() {
        use crate::trajectory::TrajectoryCoordinate;

        let a = [1.0, 0.0, 0.0];
        let b = [0.9, 0.436, 0.0]; // Cosine ~0.9

        let coord_a = TrajectoryCoordinate::new(1, 0, 0.9, 0.2);
        let coord_b = TrajectoryCoordinate::new(2, 0, 0.8, 0.3);

        // With coord_weight = 0.3, should blend both
        let sim = trajectory_weighted_cosine(&a, &b, &coord_a, &coord_b, 0.3);
        assert!(sim > 0.7 && sim < 1.0); // Blended result
    }

    #[test]
    fn test_trajectory_weighted_cosine_5d() {
        use crate::trajectory::TrajectoryCoordinate5D;

        let a = [1.0, 0.0, 0.0];
        let b = [1.0, 0.0, 0.0];

        let coord_a = TrajectoryCoordinate5D::new(1, 0, 0.9, 0.2, 1);
        let coord_b = TrajectoryCoordinate5D::new(1, 0, 0.9, 0.2, 3); // Different complexity

        // Identical embeddings, similar coords except complexity
        let sim = trajectory_weighted_cosine_5d(&a, &b, &coord_a, &coord_b, 0.3);
        assert!(sim > 0.9); // High similarity despite complexity difference
    }

    #[test]
    fn test_trajectory_weighted_l2() {
        use crate::trajectory::TrajectoryCoordinate;

        let a = [1.0, 0.0, 0.0];
        let b = [1.0, 0.0, 0.0]; // Identical

        let coord_a = TrajectoryCoordinate::new(0, 0, 1.0, 0.0);
        let coord_b = TrajectoryCoordinate::new(0, 0, 1.0, 0.0);

        // Both identical - should be 0
        let dist = trajectory_weighted_l2(&a, &b, &coord_a, &coord_b, 0.5);
        assert!(dist.abs() < 1e-6);
    }

    #[test]
    fn test_trajectory_weighted_inner_product() {
        use crate::trajectory::TrajectoryCoordinate;

        let a = [1.0, 0.0, 0.0];
        let b = [1.0, 0.0, 0.0];

        let coord_a = TrajectoryCoordinate::new(1, 0, 0.9, 0.2);
        let coord_b = TrajectoryCoordinate::new(1, 0, 0.9, 0.2);

        let sim = trajectory_weighted_inner_product(&a, &b, &coord_a, &coord_b, 0.5);
        // Identical vectors: IP=1.0, coords identical: coord_sim=4.0
        // Result should be 0.5 * 1.0 + 0.5 * 4.0 = 2.5
        assert!((sim - 2.5).abs() < 1e-6);
    }

    #[test]
    fn test_trajectory_weighted_distance_dispatch() {
        use crate::trajectory::TrajectoryCoordinate;

        let a = [1.0, 0.0, 0.0];
        let b = [0.9, 0.436, 0.0];

        let coord_a = TrajectoryCoordinate::new(1, 0, 0.9, 0.2);
        let coord_b = TrajectoryCoordinate::new(2, 0, 0.8, 0.3);

        // Each distance type must produce the same value as calling the
        // corresponding metric-specific function directly.
        let via_l2 =
            trajectory_weighted_distance(&a, &b, &coord_a, &coord_b, DistanceType::L2, 0.3);
        let direct_l2 = trajectory_weighted_l2(&a, &b, &coord_a, &coord_b, 0.3);
        assert!((via_l2 - direct_l2).abs() < 1e-6);

        let via_ip = trajectory_weighted_distance(
            &a,
            &b,
            &coord_a,
            &coord_b,
            DistanceType::InnerProduct,
            0.3,
        );
        let direct_ip = trajectory_weighted_inner_product(&a, &b, &coord_a, &coord_b, 0.3);
        assert!((via_ip - direct_ip).abs() < 1e-6);

        let via_cos =
            trajectory_weighted_distance(&a, &b, &coord_a, &coord_b, DistanceType::Cosine, 0.3);
        let direct_cos = trajectory_weighted_cosine(&a, &b, &coord_a, &coord_b, 0.3);
        assert!((via_cos - direct_cos).abs() < 1e-6);
    }

    #[test]
    fn test_trajectory_distance_config_presets() {
        let semantic = TrajectoryDistanceConfig::semantic_only();
        assert_eq!(semantic.coord_weight, 0.0);

        let balanced = TrajectoryDistanceConfig::balanced();
        assert_eq!(balanced.coord_weight, 0.5);

        let trajectory = TrajectoryDistanceConfig::trajectory_focused();
        assert_eq!(trajectory.coord_weight, 0.7);
    }

    #[test]
    fn test_trajectory_distance_config_defaults() {
        let config = TrajectoryDistanceConfig::default();
        assert_eq!(config.coord_weight, 0.2);
        assert!(config.phase_boost);
        assert_eq!(config.phase_boost_amount, 0.1);
    }

    #[test]
    fn test_coord_weight_clamping() {
        use crate::trajectory::TrajectoryCoordinate;

        let a = [1.0, 0.0, 0.0];
        let b = [1.0, 0.0, 0.0];

        let coord_a = TrajectoryCoordinate::new(0, 0, 1.0, 0.0);
        let coord_b = TrajectoryCoordinate::new(0, 0, 1.0, 0.0);

        // Weight > 1.0 should be clamped to 1.0
        let sim_over = trajectory_weighted_cosine(&a, &b, &coord_a, &coord_b, 5.0);
        let sim_one = trajectory_weighted_cosine(&a, &b, &coord_a, &coord_b, 1.0);
        assert!((sim_over - sim_one).abs() < 1e-6);

        // Weight < 0.0 should be clamped to 0.0
        let sim_under = trajectory_weighted_cosine(&a, &b, &coord_a, &coord_b, -2.0);
        let sim_zero = trajectory_weighted_cosine(&a, &b, &coord_a, &coord_b, 0.0);
        assert!((sim_under - sim_zero).abs() < 1e-6);
    }
}