// dakera-inference 0.11.81
//
// Embedded inference engine for Dakera - generates embeddings locally via ONNX Runtime
// Documentation
//! Integration tests for `crates/inference` batch processing.
//!
//! Covers three regression-prone invariants that inline unit tests cannot catch:
//! 1. **Score monotonicity** — processing item A in batch [A, B] gives the same
//!    pooled vector as processing A alone. PR#476 length-sorted batching violated
//!    this by altering embedding values based on co-batch members. (DAK-5718)
//! 2. **Output ordering** — result[i] always corresponds to input[i] regardless
//!    of batch size or position.
//! 3. **Normalization safety** — zero/near-zero vectors don't produce NaN or Inf,
//!    and unit-length postcondition holds for all well-formed inputs.

use inference::batch::{mean_pooling, normalize_embeddings};

// ─────────────────────────────────────────────────────────────
// mean_pooling — output shape
// ─────────────────────────────────────────────────────────────

#[test]
fn mean_pooling_output_shape_matches_batch_x_hidden() {
    // Zero-filled hidden states laid out for batch=3, seq=4, hidden=5, with
    // every token marked active. Only the output dimensions are checked here.
    let hidden_states = vec![0.0f32; 60];
    let attention_mask = vec![1i64; 12];
    let pooled = mean_pooling(&hidden_states, 3, 4, 5, &attention_mask);

    assert_eq!(pooled.len(), 3, "should produce one vector per batch item");
    let every_row_is_hidden_sized = pooled.iter().all(|row| row.len() == 5);
    assert!(every_row_is_hidden_sized, "each vector should have hidden_size=5 dims");
}

#[test]
fn mean_pooling_single_batch_known_value() {
    // One item (batch=1) of three tokens (seq=3), each with hidden state
    // [4.0, 6.0] (hidden=2). With every token active the mean is [4.0, 6.0].
    let hidden_states = [4.0f32, 6.0].repeat(3);
    let attention_mask = vec![1i64; 3];
    let pooled = mean_pooling(&hidden_states, 1, 3, 2, &attention_mask);

    let first_dim = pooled[0][0];
    assert!((first_dim - 4.0).abs() < 1e-6, "dim0 should be 4.0");
    let second_dim = pooled[0][1];
    assert!((second_dim - 6.0).abs() < 1e-6, "dim1 should be 6.0");
}

#[test]
fn mean_pooling_masked_tokens_do_not_contribute() {
    // Layout: batch=1, seq=2, hidden=2.
    //   token 0 = [1.0, 1.0], mask 1 -> counted
    //   token 1 = [9.0, 9.0], mask 0 -> padding, must be ignored
    let active_token = [1.0f32, 1.0];
    let padding_token = [9.0f32, 9.0];
    let hidden_states = [active_token, padding_token].concat();
    let attention_mask = vec![1i64, 0];
    let pooled = mean_pooling(&hidden_states, 1, 2, 2, &attention_mask);

    let pooled_item = &pooled[0];
    assert!(
        (pooled_item[0] - 1.0).abs() < 1e-6,
        "masked token [9.0] must not shift result; got {}",
        pooled_item[0]
    );
    assert!(
        (pooled_item[1] - 1.0).abs() < 1e-6,
        "got {}",
        pooled_item[1]
    );
}

// ─────────────────────────────────────────────────────────────
// Score monotonicity invariant (guards against PR#476 regression)
// ─────────────────────────────────────────────────────────────

#[test]
fn mean_pooling_batch_position_independence() {
    // Pooling item A inside the batch [A, B] must give the same vector as
    // pooling A alone. A difference would mean co-batch members distort
    // embeddings — the exact regression mechanism of PR#476.
    //
    // A: seq=2, hidden=3, every token = [1.0, 2.0, 3.0], both tokens active
    // B: seq=2, hidden=3, token 0 = [5.0, 6.0, 7.0] active, token 1 masked
    let lhs_a_alone = vec![
        1.0f32, 2.0, 3.0, // A token 0
        1.0, 2.0, 3.0, // A token 1
    ];
    let mask_a_alone = vec![1i64, 1i64];
    let result_a = mean_pooling(&lhs_a_alone, 1, 2, 3, &mask_a_alone);

    let lhs_pair = vec![
        1.0f32, 2.0, 3.0, // A token 0
        1.0, 2.0, 3.0, // A token 1
        5.0, 6.0, 7.0, // B token 0 (active)
        0.0, 0.0, 0.0, // B token 1 (masked)
    ];
    let mask_pair = vec![1i64, 1i64, 1i64, 0i64];
    let result_pair = mean_pooling(&lhs_pair, 2, 2, 3, &mask_pair);

    let solo_a = &result_a[0];
    let batched_a = &result_pair[0];

    // `zip` stops at the shorter iterator, so an empty or truncated pooled
    // vector would let the comparison loop below pass without checking
    // anything. Pin both lengths before comparing element-wise.
    assert_eq!(solo_a.len(), 3, "solo A should have hidden_size=3 dims");
    assert_eq!(
        batched_a.len(),
        solo_a.len(),
        "batched A must have the same number of dims as solo A"
    );

    for (i, (solo, batched)) in solo_a.iter().zip(batched_a.iter()).enumerate() {
        assert!(
            (solo - batched).abs() < 1e-6,
            "dim {i}: solo={solo} vs batched={batched} — co-batch member B must not affect A"
        );
    }
}

#[test]
fn mean_pooling_output_ordering_preserves_input_order() {
    // Three single-token, single-dimension items (batch=3, seq=1, hidden=1),
    // each carrying a distinct value so that any reordering of the output
    // relative to the input is detectable.
    let hidden_states = vec![10.0f32, 20.0, 30.0];
    let attention_mask = vec![1i64; 3];
    let pooled = mean_pooling(&hidden_states, 3, 1, 1, &attention_mask);

    let item0 = &pooled[0];
    assert!(
        (item0[0] - 10.0).abs() < 1e-6,
        "item 0 should be 10.0, got {}",
        item0[0]
    );

    let item1 = &pooled[1];
    assert!(
        (item1[0] - 20.0).abs() < 1e-6,
        "item 1 should be 20.0, got {}",
        item1[0]
    );

    let item2 = &pooled[2];
    assert!(
        (item2[0] - 30.0).abs() < 1e-6,
        "item 2 should be 30.0, got {}",
        item2[0]
    );
}

#[test]
fn mean_pooling_all_mask_zero_does_not_panic_or_nan() {
    // batch=1, seq=2, hidden=2 with every token masked out, so the mask sum
    // is zero. The pooling divisor is expected to be clamped
    // (max(sum, 1e-9)), which prevents a divide-by-zero.
    let lhs = vec![1.0f32; 2 * 2];
    let mask = vec![0i64, 0i64];
    let result = mean_pooling(&lhs, 1, 2, 2, &mask);

    // Guard against a vacuous pass: with an empty pooled vector the loop
    // below would never execute a single assertion.
    assert_eq!(
        result[0].len(),
        2,
        "pooled vector should have hidden_size=2 dims"
    );
    for v in &result[0] {
        assert!(
            v.is_finite(),
            "all-zero mask should not produce NaN/Inf, got {v}"
        );
    }
}

#[test]
fn mean_pooling_partial_mask_computes_active_average() {
    // One item with four single-dimension tokens (batch=1, seq=4, hidden=1).
    // Only the first two tokens are active, so the pooled value is
    // (2 + 4) / 2 = 3.0; the masked 8 and 16 must be excluded.
    let hidden_states = vec![2.0f32, 4.0, 8.0, 16.0];
    let attention_mask = vec![1i64, 1, 0, 0];
    let pooled = mean_pooling(&hidden_states, 1, 4, 1, &attention_mask);

    let pooled_item = &pooled[0];
    assert!(
        (pooled_item[0] - 3.0).abs() < 1e-6,
        "expected mean of active tokens = 3.0, got {}",
        pooled_item[0]
    );
}

// ─────────────────────────────────────────────────────────────
// normalize_embeddings — postconditions
// ─────────────────────────────────────────────────────────────

#[test]
fn normalize_l2_norm_is_one() {
    // The 3-4-5 triangle: [3, 4] has L2 norm 5 and normalizes to [0.6, 0.8],
    // so the norm recomputed after normalization must be 1.
    let mut rows = vec![vec![3.0f32, 4.0]];
    normalize_embeddings(&mut rows);

    let sum_of_squares: f32 = rows[0].iter().map(|c| c * c).sum();
    let norm = sum_of_squares.sqrt();
    assert!(
        (norm - 1.0).abs() < 1e-5,
        "L2 norm after normalization should be 1.0, got {norm}"
    );
}

#[test]
fn normalize_known_3_4_vector_values() {
    // [3, 4] divided by its L2 norm of 5 gives the components [0.6, 0.8].
    let mut rows = vec![vec![3.0f32, 4.0]];
    normalize_embeddings(&mut rows);

    let unit = &rows[0];
    assert!(
        (unit[0] - 0.6).abs() < 1e-5,
        "expected 0.6, got {}",
        unit[0]
    );
    assert!(
        (unit[1] - 0.8).abs() < 1e-5,
        "expected 0.8, got {}",
        unit[1]
    );
}

#[test]
fn normalize_each_row_normalized_independently() {
    // The basis rows [1,0] and [0,1] are already unit length; after
    // normalization each row, taken on its own, must still have L2 norm 1.
    let mut rows = vec![vec![1.0f32, 0.0], vec![0.0f32, 1.0]];
    normalize_embeddings(&mut rows);

    rows.iter().enumerate().for_each(|(i, row)| {
        let sum_of_squares: f32 = row.iter().map(|c| c * c).sum();
        let norm = sum_of_squares.sqrt();
        assert!(
            (norm - 1.0).abs() < 1e-5,
            "row {i} should have L2 norm 1.0, got {norm}"
        );
    });
}

#[test]
fn normalize_near_zero_vector_produces_finite_output() {
    // A vector whose components are tiny (1e-15). The normalization divisor
    // is expected to be clamped (max(norm, 1e-12)), so no component may come
    // out as NaN or Inf.
    let mut rows = vec![vec![1e-15f32; 2]];
    normalize_embeddings(&mut rows);

    rows[0].iter().for_each(|v| {
        assert!(
            v.is_finite(),
            "near-zero vector should not produce NaN/Inf, got {v}"
        );
    });
}

#[test]
fn normalize_negative_values_produce_unit_vector() {
    // Sign does not matter for the L2 norm: [-3, -4] still has norm 5, so the
    // normalized vector must have unit length.
    let mut rows = vec![vec![-3.0f32, -4.0]];
    normalize_embeddings(&mut rows);

    let sum_of_squares: f32 = rows[0].iter().map(|c| c * c).sum();
    let norm = sum_of_squares.sqrt();
    assert!(
        (norm - 1.0).abs() < 1e-5,
        "negative input should still normalize to unit vector, got norm={norm}"
    );
}

#[test]
fn normalize_already_unit_vector_unchanged() {
    // [0.6, 0.8] already has unit length, so normalizing it again must leave
    // both components where they were (idempotence).
    let mut rows = vec![vec![0.6f32, 0.8]];
    normalize_embeddings(&mut rows);

    let unit = &rows[0];
    assert!(
        (unit[0] - 0.6).abs() < 1e-5,
        "idempotent: expected 0.6, got {}",
        unit[0]
    );
    assert!(
        (unit[1] - 0.8).abs() < 1e-5,
        "idempotent: expected 0.8, got {}",
        unit[1]
    );
}

#[test]
fn normalize_preserves_output_shape() {
    // Four identical rows of five components; normalization must leave both
    // the row count and every row's length untouched.
    let mut rows = vec![vec![1.0f32, 2.0, 3.0, 4.0, 5.0]; 4];
    normalize_embeddings(&mut rows);

    assert_eq!(rows.len(), 4, "row count unchanged");
    let widths_intact = rows.iter().all(|row| row.len() == 5);
    assert!(widths_intact, "column count unchanged");
}