svod-tensor 0.1.0-alpha.3

High-level lazy tensor API for the Svod ML compiler
Documentation
use svod_dtype::DType;

use crate::Tensor;

crate::codegen_tests! {
    // Builds a [2, 3] f32 tensor from little-endian bytes, checks the symbolic
    // shape before realization, then checks the realized element values.
    fn test_from_raw_bytes_f32(config) {
        let values = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let bytes: Vec<u8> = values.iter().flat_map(|v| v.to_le_bytes()).collect();
        let mut t = Tensor::from_raw_bytes(&bytes, &[2, 3], DType::Float32).unwrap();
        let shape = t.shape().unwrap();
        assert_eq!(shape.len(), 2);
        assert_eq!(shape[0].as_const().unwrap(), 2);
        assert_eq!(shape[1].as_const().unwrap(), 3);
        t.realize_with(&config).unwrap();
        // The realized data must match the source values exactly (no conversion).
        assert_eq!(t.as_vec::<f32>().unwrap(), values);
    }

    // Builds a [2] f16 tensor from raw half-precision bit patterns
    // (0x3C00 and 0x4000 encode 1.0 and 2.0), checks dtype and shape, then
    // casts to f32 so the realized values can be read back and compared.
    fn test_from_raw_bytes_f16(config) {
        let f16_bits: [u16; 2] = [0x3C00, 0x4000];
        let bytes: Vec<u8> = f16_bits.iter().flat_map(|v| v.to_le_bytes()).collect();
        let t = Tensor::from_raw_bytes(&bytes, &[2], DType::Float16).unwrap();
        assert_eq!(t.uop().dtype(), DType::Float16);
        let shape = t.shape().unwrap();
        assert_eq!(shape.len(), 1);
        assert_eq!(shape[0].as_const().unwrap(), 2);

        let mut t_f32 = t.cast(DType::Float32).unwrap();
        t_f32.realize_with(&config).unwrap();
        let vals = t_f32.as_vec::<f32>().unwrap();
        // Guard the element count so a wrong length fails with a readable
        // message rather than an index-out-of-bounds panic below.
        assert_eq!(vals.len(), 2, "expected 2 elements, got {vals:?}");
        assert!((vals[0] - 1.0).abs() < 1e-3, "vals[0] = {}", vals[0]);
        assert!((vals[1] - 2.0).abs() < 1e-3, "vals[1] = {}", vals[1]);
    }

    // A square identity matrix: ones on the diagonal, zeros elsewhere,
    // read back in row-major order.
    fn test_eye_square(config) {
        let mut eye = Tensor::eye(3, 3, DType::Float32).unwrap();
        let shape = eye.shape().unwrap();
        assert_eq!(shape.len(), 2);
        assert_eq!(shape[0].as_const().unwrap(), 3);
        assert_eq!(shape[1].as_const().unwrap(), 3);
        eye.realize_with(&config).unwrap();
        assert_eq!(
            eye.as_vec::<f32>().unwrap(),
            [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]
        );
    }

    // A 2x4 "identity": ones at (0, 0) and (1, 1) only; the extra columns
    // stay zero.
    fn test_eye_rectangular(config) {
        let mut eye = Tensor::eye(2, 4, DType::Float32).unwrap();
        let shape = eye.shape().unwrap();
        assert_eq!(shape.len(), 2);
        assert_eq!(shape[0].as_const().unwrap(), 2);
        assert_eq!(shape[1].as_const().unwrap(), 4);
        eye.realize_with(&config).unwrap();
        assert_eq!(
            eye.as_vec::<f32>().unwrap(),
            [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]
        );
    }

    // The degenerate 1x1 case: still rank 2, and its single element is 1.0.
    // Shape is checked like the sibling eye tests, not just the value.
    fn test_eye_single(config) {
        let mut eye = Tensor::eye(1, 1, DType::Float32).unwrap();
        let shape = eye.shape().unwrap();
        assert_eq!(shape.len(), 2);
        assert_eq!(shape[0].as_const().unwrap(), 1);
        assert_eq!(shape[1].as_const().unwrap(), 1);
        eye.realize_with(&config).unwrap();
        let view = eye.array_view::<f32>().unwrap();
        assert_eq!(view[[0, 0]], 1.0);
    }
}

// A 10-byte buffer cannot back a [3] f32 tensor (which needs 3 * 4 = 12 bytes).
// `from_raw_bytes` must reject it, and the error text must name the function.
#[test]
fn test_from_raw_bytes_wrong_length() {
    let short_buffer = vec![0u8; 10];
    let result = Tensor::from_raw_bytes(&short_buffer, &[3], DType::Float32);
    assert!(result.is_err());

    // The assertion above guarantees the `Err` arm is taken.
    let err = match result {
        Err(e) => e.to_string(),
        Ok(_) => unreachable!(),
    };
    assert!(
        err.contains("from_raw_bytes"),
        "Error should mention from_raw_bytes: {err}"
    );
}