innr 0.4.0

SIMD-accelerated vector similarity primitives with binary, ternary, and scalar quantization
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
//! Property-based tests for SIMD correctness.
//!
//! These tests verify that SIMD implementations produce identical results
//! to portable implementations across various input sizes and value ranges.

use proptest::prelude::*;

// Reference portable implementations for comparison
/// Sequential dot product of `a` and `b`, used as ground truth by the tests in this file.
///
/// Elements are paired with `zip`, so trailing elements of the longer slice are ignored.
fn dot_reference(a: &[f32], b: &[f32]) -> f32 {
    std::iter::zip(a, b).map(|(&x, &y)| x * y).sum::<f32>()
}

/// Sequential squared Euclidean distance between `a` and `b`, used as ground truth.
///
/// Elements are paired with `zip`, so trailing elements of the longer slice are ignored.
fn l2_sq_reference(a: &[f32], b: &[f32]) -> f32 {
    std::iter::zip(a, b)
        .map(|(&x, &y)| (x - y) * (x - y))
        .sum::<f32>()
}

/// Reference Euclidean norm: the square root of `dot_reference(v, v)`.
fn norm_reference(v: &[f32]) -> f32 {
    f32::sqrt(dot_reference(v, v))
}

/// Strategy yielding two independent `Vec<f32>` of exactly `len` elements each,
/// with every component drawn from `-100.0..100.0`.
fn arb_vec_pair(len: usize) -> impl Strategy<Value = (Vec<f32>, Vec<f32>)> {
    // Both halves of the pair share the same element range and length.
    let one_vector = || proptest::collection::vec(-100.0f32..100.0, len);
    (one_vector(), one_vector())
}

proptest! {
    #![proptest_config(ProptestConfig {
        cases: 500,
        ..ProptestConfig::default()
    })]

    // ─────────────────────────────────────────────────────────────────────────
    // Dot product tests
    // ─────────────────────────────────────────────────────────────────────────

    /// `innr::dot` agrees with the sequential reference on 8-element vectors.
    ///
    /// The allowed error is proportional to the sum of |a_i * b_i| rather than
    /// to |result|: when positive and negative products cancel, the result can
    /// be close to zero even though the intermediate terms are large. The
    /// library may add the terms in a different order than the reference,
    /// which gives an error on the order of n * eps * sum|a_i * b_i|.
    #[test]
    fn dot_small_matches_reference((a, b) in arb_vec_pair(8)) {
        let reference = dot_reference(&a, &b);
        let actual = innr::dot(&a, &b);
        let magnitude = std::iter::zip(&a, &b)
            .map(|(&x, &y)| (x * y).abs())
            .sum::<f32>();
        let tolerance = magnitude * 1e-5 + 1e-4;
        let diff = (actual - reference).abs();
        prop_assert!(
            diff < tolerance,
            "Small dot mismatch: {} vs {} (diff: {}, tol: {})",
            actual, reference, diff, tolerance
        );
    }

    /// `innr::dot` agrees with the sequential reference on 64-element vectors.
    ///
    /// The library may accumulate in a different order than sequential code,
    /// so the two results can round differently; the tolerance therefore
    /// scales with the sum of |a_i * b_i| to stay meaningful under cancellation.
    #[test]
    fn dot_medium_matches_reference((a, b) in arb_vec_pair(64)) {
        let reference = dot_reference(&a, &b);
        let actual = innr::dot(&a, &b);
        let magnitude = std::iter::zip(&a, &b)
            .map(|(&x, &y)| (x * y).abs())
            .sum::<f32>();
        let tolerance = magnitude * 1e-5 + 1e-3;
        let diff = (actual - reference).abs();
        prop_assert!(
            diff < tolerance,
            "Medium dot mismatch: {} vs {} (diff: {}, tol: {})",
            actual, reference, diff, tolerance
        );
    }

    /// `innr::dot` agrees with the sequential reference on 256-element vectors.
    ///
    /// With components in [-100, 100] and 256 elements, |dot| can reach about
    /// 2.56M, and differences in accumulation order add up over many terms.
    /// The tolerance is tied to the sum of |a_i * b_i| (not to the possibly
    /// tiny |expected|), plus a small absolute floor.
    #[test]
    fn dot_large_matches_reference((a, b) in arb_vec_pair(256)) {
        let reference = dot_reference(&a, &b);
        let actual = innr::dot(&a, &b);
        // Magnitude of the intermediate products, independent of their signs.
        let magnitude = std::iter::zip(&a, &b)
            .map(|(&x, &y)| (x * y).abs())
            .sum::<f32>();
        let tolerance = magnitude * 1e-5 + 1e-2;
        let diff = (actual - reference).abs();
        prop_assert!(
            diff < tolerance,
            "Large dot mismatch: {} vs {} (diff: {}, tol: {})",
            actual, reference, diff, tolerance
        );
    }

    /// Swapping the arguments of `innr::dot` changes the result by less than 1e-6.
    #[test]
    fn dot_commutative((a, b) in arb_vec_pair(128)) {
        let (ab, ba) = (innr::dot(&a, &b), innr::dot(&b, &a));
        let gap = (ab - ba).abs();
        prop_assert!(gap < 1e-6, "Dot not commutative: {} != {}", ab, ba);
    }

    // ─────────────────────────────────────────────────────────────────────────
    // L2 distance squared tests
    // ─────────────────────────────────────────────────────────────────────────

    /// `innr::l2_distance_squared` agrees with the sequential reference on
    /// 8-element vectors, within a relative 1e-5 plus an absolute 1e-6.
    #[test]
    fn l2_sq_small_matches_reference((a, b) in arb_vec_pair(8)) {
        let reference = l2_sq_reference(&a, &b);
        let actual = innr::l2_distance_squared(&a, &b);
        let tolerance = reference.abs() * 1e-5 + 1e-6;
        let diff = (actual - reference).abs();
        prop_assert!(
            diff < tolerance,
            "Small L2sq mismatch: {} vs {}",
            actual, reference
        );
    }

    /// `innr::l2_distance_squared` agrees with the sequential reference on
    /// 64-element vectors, within a relative 1e-4 plus an absolute 1e-5.
    #[test]
    fn l2_sq_medium_matches_reference((a, b) in arb_vec_pair(64)) {
        let reference = l2_sq_reference(&a, &b);
        let actual = innr::l2_distance_squared(&a, &b);
        let tolerance = reference.abs() * 1e-4 + 1e-5;
        let diff = (actual - reference).abs();
        prop_assert!(
            diff < tolerance,
            "Medium L2sq mismatch: {} vs {}",
            actual, reference
        );
    }

    /// `innr::l2_distance_squared` agrees with the sequential reference on
    /// 256-element vectors, within a relative 1e-4 plus an absolute 1e-4.
    #[test]
    fn l2_sq_large_matches_reference((a, b) in arb_vec_pair(256)) {
        let reference = l2_sq_reference(&a, &b);
        let actual = innr::l2_distance_squared(&a, &b);
        let tolerance = reference.abs() * 1e-4 + 1e-4;
        let diff = (actual - reference).abs();
        prop_assert!(
            diff < tolerance,
            "Large L2sq mismatch: {} vs {}",
            actual, reference
        );
    }

    /// `innr::l2_distance_squared` gives the same value (within 1e-6) for
    /// `(a, b)` and `(b, a)`.
    #[test]
    fn l2_sq_symmetric((a, b) in arb_vec_pair(128)) {
        let forward = innr::l2_distance_squared(&a, &b);
        let backward = innr::l2_distance_squared(&b, &a);
        let gap = (forward - backward).abs();
        prop_assert!(gap < 1e-6, "L2sq not symmetric: {} != {}", forward, backward);
    }

    /// `innr::l2_distance_squared` never returns a value below zero.
    #[test]
    fn l2_sq_nonnegative((a, b) in arb_vec_pair(128)) {
        let dist_sq = innr::l2_distance_squared(&a, &b);
        prop_assert!(dist_sq >= 0.0, "L2sq should be non-negative, got {}", dist_sq);
    }

    /// The squared distance from a vector to itself is zero (within 1e-6).
    #[test]
    fn l2_sq_self_is_zero(v in proptest::collection::vec(-100.0f32..100.0, 128)) {
        let self_dist = innr::l2_distance_squared(&v, &v);
        let magnitude = self_dist.abs();
        prop_assert!(magnitude < 1e-6, "L2sq to self should be 0, got {}", self_dist);
    }

    // ─────────────────────────────────────────────────────────────────────────
    // Cosine similarity tests
    // ─────────────────────────────────────────────────────────────────────────

    /// `innr::cosine` stays inside [-1, 1], widened by 1e-5 on each side.
    ///
    /// Inputs are filtered so that each vector has at least one component
    /// whose magnitude exceeds 1e-6.
    #[test]
    fn cosine_bounded(
        (a, b) in arb_vec_pair(128).prop_filter("non-zero", |(a, b)| {
            [a, b].iter().all(|v| v.iter().any(|x| x.abs() > 1e-6))
        })
    ) {
        let similarity = innr::cosine(&a, &b);
        let allowed = (-1.0 - 1e-5)..=(1.0 + 1e-5);
        prop_assert!(
            allowed.contains(&similarity),
            "Cosine out of bounds: {}",
            similarity
        );
    }

    /// `innr::cosine` gives the same value (within 1e-5) for `(a, b)` and `(b, a)`.
    #[test]
    fn cosine_symmetric((a, b) in arb_vec_pair(128)) {
        let (ab, ba) = (innr::cosine(&a, &b), innr::cosine(&b, &a));
        let gap = (ab - ba).abs();
        prop_assert!(gap < 1e-5, "Cosine not symmetric: {} != {}", ab, ba);
    }

    /// The cosine similarity of a vector with itself is 1 (within 1e-5).
    ///
    /// Inputs are filtered to have at least one component whose magnitude
    /// exceeds 1e-6.
    #[test]
    fn cosine_self_is_one(
        v in proptest::collection::vec(-100.0f32..100.0, 128)
            .prop_filter("non-zero", |candidate| candidate.iter().any(|c| c.abs() > 1e-6))
    ) {
        let self_similarity = innr::cosine(&v, &v);
        let deviation = (self_similarity - 1.0).abs();
        prop_assert!(
            deviation < 1e-5,
            "Cosine of self should be 1.0, got {}",
            self_similarity
        );
    }

    // ─────────────────────────────────────────────────────────────────────────
    // Norm tests
    // ─────────────────────────────────────────────────────────────────────────

    /// `innr::norm` never returns a value below zero.
    #[test]
    fn norm_nonnegative(v in proptest::collection::vec(-100.0f32..100.0, 128)) {
        let length = innr::norm(&v);
        prop_assert!(length >= 0.0, "Norm should be non-negative, got {}", length);
    }

    /// `innr::norm` agrees with `norm_reference` on 128-element vectors,
    /// within a relative 1e-5 plus an absolute 1e-6.
    #[test]
    fn norm_matches_reference(v in proptest::collection::vec(-100.0f32..100.0, 128)) {
        let reference = norm_reference(&v);
        let actual = innr::norm(&v);
        let tolerance = reference.abs() * 1e-5 + 1e-6;
        let diff = (actual - reference).abs();
        prop_assert!(diff < tolerance, "Norm mismatch: {} vs {}", actual, reference);
    }

    /// Multiplying every component by a positive `scale` multiplies the norm
    /// by `scale`, within a relative 1e-4 plus an absolute 1e-5.
    #[test]
    fn norm_scales_with_scalar(
        v in proptest::collection::vec(-10.0f32..10.0, 64),
        scale in 0.1f32..10.0
    ) {
        let norm_v = innr::norm(&v);
        let stretched = v.iter().map(|&x| x * scale).collect::<Vec<f32>>();
        let norm_scaled = innr::norm(&stretched);

        let expected = norm_v * scale;
        let tolerance = expected.abs() * 1e-4 + 1e-5;
        let diff = (norm_scaled - expected).abs();
        prop_assert!(
            diff < tolerance,
            "Norm scaling violated: {} != {} * {} = {}",
            norm_scaled, norm_v, scale, expected
        );
    }
}

// ─────────────────────────────────────────────────────────────────────────────
// SIMD boundary tests (specific sizes that stress SIMD implementations)
// ─────────────────────────────────────────────────────────────────────────────

/// Compares `innr::dot` with `dot_reference` on random vectors whose lengths sit
/// on, just below, and just above multiples of common SIMD lane counts.
///
/// Inputs come from an unseeded RNG, so each run exercises different values.
#[test]
fn test_dot_at_simd_boundaries() {
    use rand::Rng;
    let mut rng = rand::rng();

    // Test sizes that are: exact SIMD width, SIMD width - 1, SIMD width + 1
    // AVX-512: 16 floats, AVX2: 8 floats, NEON: 4 floats
    let sizes = [
        1, 3, 4, 5, 7, 8, 9, 15, 16, 17, 23, 24, 25, 31, 32, 33, 47, 48, 49, 63, 64, 65, 127, 128,
        129, 255, 256, 257,
    ];

    for &size in &sizes {
        let a: Vec<f32> = (0..size).map(|_| rng.random_range(-10.0..10.0)).collect();
        let b: Vec<f32> = (0..size).map(|_| rng.random_range(-10.0..10.0)).collect();

        let result = innr::dot(&a, &b);
        let expected = dot_reference(&a, &b);

        // Scale the tolerance with the sum of |products| rather than |expected|.
        // Random inputs of mixed sign can cancel to a dot product near zero while
        // the rounding error from a different accumulation order (or FMA vs
        // non-FMA) stays proportional to sum|a_i * b_i|; a tolerance relative to
        // |expected| would then collapse to the absolute floor and fail spuriously.
        let sum_abs_products: f32 = a
            .iter()
            .zip(b.iter())
            .map(|(x, y)| (x * y).abs())
            .sum();
        let tolerance = sum_abs_products * 1e-5 + 1e-5;
        let diff = (result - expected).abs();
        assert!(
            diff < tolerance,
            "Dot at size {}: {} vs {} (diff: {})",
            size,
            result,
            expected,
            diff
        );
    }
}

/// Compares `innr::l2_distance_squared` with `l2_sq_reference` on random vectors
/// for a fixed list of lengths from 1 to 257.
///
/// Inputs come from an unseeded RNG; components are drawn from `-10.0..10.0`.
#[test]
fn test_l2_sq_at_simd_boundaries() {
    use rand::Rng;
    let mut rng = rand::rng();

    let sizes = [
        1, 3, 4, 5, 7, 8, 9, 15, 16, 17, 23, 24, 25, 31, 32, 33, 47, 48, 49, 63, 64, 65, 127, 128,
        129, 255, 256, 257,
    ];

    for size in sizes.iter().copied() {
        let lhs: Vec<f32> = (0..size).map(|_| rng.random_range(-10.0..10.0)).collect();
        let rhs: Vec<f32> = (0..size).map(|_| rng.random_range(-10.0..10.0)).collect();

        let reference = l2_sq_reference(&lhs, &rhs);
        let actual = innr::l2_distance_squared(&lhs, &rhs);

        // Relative bound plus a small absolute floor.
        let tolerance = reference.abs() * 1e-4 + 1e-5;
        let diff = (actual - reference).abs();
        assert!(
            diff < tolerance,
            "L2sq at size {}: {} vs {} (diff: {})",
            size,
            actual,
            reference,
            diff
        );
    }
}

// =============================================================================
// Batch operation property tests
// =============================================================================

mod batch_props {
    use super::*;
    use innr::batch::{
        batch_cosine, batch_dot, batch_knn, batch_l2_squared, batch_norms, VerticalBatch,
    };

    /// Strategy yielding `(rows, query)`: `num_vectors` rows plus one query
    /// vector, each with `dim` components drawn from `-10.0..10.0`.
    fn arb_batch(
        num_vectors: usize,
        dim: usize,
    ) -> impl Strategy<Value = (Vec<Vec<f32>>, Vec<f32>)> {
        // Rows and the query share the same per-vector strategy.
        let one_vector = || proptest::collection::vec(-10.0f32..10.0, dim);
        (
            proptest::collection::vec(one_vector(), num_vectors),
            one_vector(),
        )
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(200))]

        /// Each entry of `batch_l2_squared` agrees with
        /// `innr::l2_distance_squared` computed against the corresponding row.
        #[test]
        fn batch_l2_matches_individual((vectors, query) in arb_batch(20, 32)) {
            let batch = VerticalBatch::from_rows(&vectors);
            let batched = batch_l2_squared(&query, &batch);

            for (i, row) in vectors.iter().enumerate() {
                let single = innr::l2_distance_squared(&query, row);
                let tolerance = single.abs() * 1e-4 + 1e-5;
                let diff = (batched[i] - single).abs();
                prop_assert!(
                    diff < tolerance,
                    "Batch L2 mismatch at {}: {} vs {}",
                    i, batched[i], single
                );
            }
        }

        /// Each entry of `batch_dot` agrees with `innr::dot` computed against
        /// the corresponding row.
        ///
        /// The batched and per-row paths may accumulate in different orders,
        /// so the tolerance scales with the sum of |q_i * r_i|.
        #[test]
        fn batch_dot_matches_individual((vectors, query) in arb_batch(20, 32)) {
            let batch = VerticalBatch::from_rows(&vectors);
            let batched = batch_dot(&query, &batch);

            for (i, row) in vectors.iter().enumerate() {
                let single = innr::dot(&query, row);
                // Magnitude of the products, independent of their signs.
                let magnitude = std::iter::zip(&query, row)
                    .map(|(&q, &r)| (q * r).abs())
                    .sum::<f32>();
                let tolerance = magnitude * 1e-4 + 1e-4;
                let diff = (batched[i] - single).abs();
                prop_assert!(
                    diff < tolerance,
                    "Batch dot mismatch at {}: {} vs {} (tol: {})",
                    i, batched[i], single, tolerance
                );
            }
        }

        /// Every value returned by `batch_norms` is at least zero.
        #[test]
        fn batch_norms_nonnegative((vectors, _) in arb_batch(20, 32)) {
            let batch = VerticalBatch::from_rows(&vectors);
            let norms = batch_norms(&batch);

            for (i, n) in norms.iter().copied().enumerate() {
                prop_assert!(n >= 0.0, "Norm at {} should be non-negative: {}", i, n);
            }
        }

        /// Every value returned by `batch_cosine` lies in [-1, 1], widened by
        /// 1e-4 on each side.
        ///
        /// Inputs are filtered so that the query and every row have at least
        /// one component whose magnitude exceeds 1e-6.
        #[test]
        fn batch_cosine_bounded(
            (vectors, query) in arb_batch(20, 32).prop_filter("non-zero", |(vecs, q)| {
                std::iter::once(q)
                    .chain(vecs.iter())
                    .all(|v| v.iter().any(|x| x.abs() > 1e-6))
            })
        ) {
            let batch = VerticalBatch::from_rows(&vectors);
            let norms = batch_norms(&batch);
            let cosines = batch_cosine(&query, &batch, &norms);

            for (i, c) in cosines.iter().copied().enumerate() {
                let allowed = (-1.0 - 1e-4)..=(1.0 + 1e-4);
                prop_assert!(
                    allowed.contains(&c),
                    "Cosine at {} out of bounds: {}",
                    i, c
                );
            }
        }

        /// The scores returned by `batch_knn` are in ascending order: each
        /// score is at least the previous one minus 1e-6.
        #[test]
        fn batch_knn_sorted((vectors, query) in arb_batch(50, 16)) {
            let batch = VerticalBatch::from_rows(&vectors);
            let result = batch_knn(&query, &batch, 10);

            for i in 1..result.scores.len() {
                let (previous, current) = (result.scores[i - 1], result.scores[i]);
                prop_assert!(
                    current >= previous - 1e-6,
                    "kNN not sorted: {} > {} at {}",
                    previous, current, i
                );
            }
        }

        /// The indices returned by `batch_knn` contain no duplicates.
        #[test]
        fn batch_knn_unique_indices((vectors, query) in arb_batch(50, 16)) {
            let batch = VerticalBatch::from_rows(&vectors);
            let result = batch_knn(&query, &batch, 20);

            let distinct = result
                .indices
                .iter()
                .collect::<std::collections::HashSet<_>>()
                .len();
            prop_assert_eq!(
                distinct,
                result.indices.len(),
                "kNN indices not unique"
            );
        }

        /// Building a `VerticalBatch` from rows and extracting each vector
        /// again returns the original components (within 1e-6).
        #[test]
        fn batch_roundtrip((vectors, _) in arb_batch(10, 16)) {
            let batch = VerticalBatch::from_rows(&vectors);

            for (i, original) in vectors.iter().enumerate() {
                let extracted = batch.extract_vector(i);
                let pairs = original.iter().zip(extracted.iter());
                for (j, (&orig, &ext)) in pairs.enumerate() {
                    let diff = (orig - ext).abs();
                    prop_assert!(
                        diff < 1e-6,
                        "Roundtrip mismatch at [{}, {}]: {} vs {}",
                        i, j, orig, ext
                    );
                }
            }
        }
    }
}

// =============================================================================
// Triangle inequality tests
// =============================================================================

proptest! {
    #![proptest_config(ProptestConfig::with_cases(200))]

    /// For any three 32-element vectors, the direct distance `a -> c` does not
    /// exceed the detour `a -> b -> c` by more than 1e-4.
    #[test]
    fn l2_triangle_inequality(
        a in proptest::collection::vec(-10.0f32..10.0, 32),
        b in proptest::collection::vec(-10.0f32..10.0, 32),
        c in proptest::collection::vec(-10.0f32..10.0, 32),
    ) {
        let (ab, bc, ac) = (
            innr::l2_distance(&a, &b),
            innr::l2_distance(&b, &c),
            innr::l2_distance(&a, &c),
        );

        // The 1e-4 slack absorbs floating-point rounding in the three distances.
        let detour = ab + bc;
        prop_assert!(
            ac <= detour + 1e-4,
            "Triangle inequality violated: {} > {} + {} = {}",
            ac, ab, bc, detour
        );
    }
}