trueno 0.18.0

High-performance SIMD compute library with GPU support, LLM inference engine, and GGUF model loading
use super::*;

/// A 1x1 matrix times a one-element vector reduces to scalar multiplication: 3 * 5 = 15.
#[test]
fn test_matvec_single_element_1x1() {
    let matrix = Matrix::from_vec_with_backend(1, 1, vec![3.0], Backend::Scalar);
    let input = Vector::from_slice(&[5.0]);
    let product = matrix.matvec(&input).unwrap();
    let got = product.as_slice()[0];
    assert!((got - 15.0).abs() < 1e-6);
}

/// A single-row matrix times an all-ones vector yields the sum of the row: 1 + 2 + 3 + 4 = 10.
#[test]
fn test_matvec_single_row() {
    let row = vec![1.0, 2.0, 3.0, 4.0];
    let matrix = Matrix::from_vec_with_backend(1, 4, row, Backend::Scalar);
    let ones = Vector::from_slice(&[1.0, 1.0, 1.0, 1.0]);
    let product = matrix.matvec(&ones).unwrap();
    let row_sum = product.as_slice()[0];
    assert!((row_sum - 10.0).abs() < 1e-6);
}

/// A single-column matrix times a one-element vector scales the column: [2, 4, 6] * 3.
#[test]
fn test_matvec_single_column() {
    let column = vec![2.0, 4.0, 6.0];
    let matrix = Matrix::from_vec_with_backend(3, 1, column, Backend::Scalar);
    let scale = Vector::from_slice(&[3.0]);
    let product = matrix.matvec(&scale).unwrap();
    // Compared exactly (no tolerance), element by element and including the length.
    let expected = [6.0, 12.0, 18.0];
    assert_eq!(product.as_slice(), &expected);
}

/// Multiplying by the 3x3 identity matrix must return the input vector unchanged.
#[test]
fn test_matvec_identity_matrix() {
    let m = Matrix::from_vec(3, 3, vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]).unwrap();
    let v = Vector::from_slice(&[7.0, 11.0, 13.0]);
    let result = m.matvec(&v).unwrap();
    let expected = [7.0, 11.0, 13.0];
    // `zip` stops at the shorter iterator, so without this check an empty or
    // truncated result would make the loop below pass vacuously.
    assert_eq!(
        result.as_slice().len(),
        expected.len(),
        "identity matvec returned the wrong number of elements"
    );
    for (i, (&got, &exp)) in result.as_slice().iter().zip(expected.iter()).enumerate() {
        assert!((got - exp).abs() < 1e-6, "identity matvec [{i}]: {got} != {exp}");
    }
}

/// An all-zero 2x3 matrix must map any vector to an all-zero result with one entry per row.
#[test]
fn test_matvec_zero_matrix() {
    let m = Matrix::from_vec(2, 3, vec![0.0; 6]).unwrap();
    let v = Vector::from_slice(&[1.0, 2.0, 3.0]);
    let result = m.matvec(&v).unwrap();
    // Guard against a vacuous pass: iterating an empty result would assert nothing.
    assert_eq!(result.as_slice().len(), 2, "zero-matrix matvec should yield one entry per row");
    for (i, &val) in result.as_slice().iter().enumerate() {
        assert!((val - 0.0).abs() < 1e-6, "zero-matrix matvec [{i}]: {val} != 0");
    }
}

/// Any 2x3 matrix times the zero vector must give an all-zero result with one entry per row.
#[test]
fn test_matvec_zero_vector() {
    let m = Matrix::from_vec(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
    let v = Vector::from_slice(&[0.0, 0.0, 0.0]);
    let result = m.matvec(&v).unwrap();
    // Guard against a vacuous pass: iterating an empty result would assert nothing.
    assert_eq!(result.as_slice().len(), 2, "zero-vector matvec should yield one entry per row");
    for (i, &val) in result.as_slice().iter().enumerate() {
        assert!((val - 0.0).abs() < 1e-6, "zero-vector matvec [{i}]: {val} != 0");
    }
}

/// A length-2 vector against a 2x3 matrix must fail with `InvalidInput`, and the
/// message must contain both the vector length and the matrix column count.
#[test]
fn test_matvec_dimension_mismatch_error_message() {
    let matrix = Matrix::from_vec(2, 3, vec![1.0; 6]).unwrap();
    let short_vec = Vector::from_slice(&[1.0, 2.0]);
    let err = matrix.matvec(&short_vec).unwrap_err();
    if let TruenoError::InvalidInput(msg) = &err {
        assert!(msg.contains('2'), "Error should mention vector length 2: {msg}");
        assert!(msg.contains('3'), "Error should mention matrix cols 3: {msg}");
    } else {
        panic!("Expected InvalidInput, got {err:?}");
    }
}

/// All-negative matrix and vector entries: every term is a product of two negatives,
/// so the row results are positive (14 and 32).
#[test]
fn test_matvec_negative_values() {
    let entries = vec![-1.0, -2.0, -3.0, -4.0, -5.0, -6.0];
    let matrix = Matrix::from_vec_with_backend(2, 3, entries, Backend::Scalar);
    let input = Vector::from_slice(&[-1.0, -2.0, -3.0]);
    let product = matrix.matvec(&input).unwrap();
    let out = product.as_slice();
    // Indexing (rather than zipping) keeps the panic on a too-short result.
    for (row, &want) in [14.0, 32.0].iter().enumerate() {
        assert!((out[row] - want).abs() < 1e-5);
    }
}