// xn 0.1.1
//
// Another minimalist deep-learning framework optimized for inference.
// Documentation
use std::collections::HashMap;
use xn::{CPU, CpuTensor, DType, Tensor, TypedTensor, safetensors};

/// Writes a single 2x2 `f32` tensor under the name `"t"` to a safetensors file
/// in the temp directory, loads the file back, and checks that exactly one
/// entry is returned with the same dims and values.
#[test]
fn save_and_load_single_tensor() {
    let file = std::env::temp_dir().join("test_single.safetensors");
    let original: CpuTensor<f32> =
        Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], (2, 2), &CPU).unwrap();
    original.save_safetensors("t", &file).unwrap();

    let restored = safetensors::load_from_file(&file, &CPU).unwrap();
    assert_eq!(restored.len(), 1);

    // The entry must come back tagged with the dtype it was saved with.
    let TypedTensor::F32(tensor) = restored.get("t").unwrap() else {
        panic!("expected F32");
    };
    assert_eq!(tensor.dims(), &[2, 2]);
    assert_eq!(tensor.to_vec().unwrap(), vec![1.0, 2.0, 3.0, 4.0]);

    // Reached only when every assertion above has passed.
    std::fs::remove_file(&file).unwrap();
}

/// Saves an `f32` tensor (`"a"`, dims 1x3) and an `i64` tensor (`"b"`, dims 2)
/// into one safetensors file via `safetensors::save`, then reloads the file and
/// checks that both entries come back with their dtype, dims and values.
#[test]
fn save_and_load_multiple_tensors() {
    let file = std::env::temp_dir().join("test_multi.safetensors");
    let floats: CpuTensor<f32> = Tensor::from_vec(vec![1.0, 2.0, 3.0], (1, 3), &CPU).unwrap();
    let ints: CpuTensor<i64> = Tensor::from_vec(vec![10, 20], (2,), &CPU).unwrap();

    // Owned `String` keys; the two entries deliberately use different dtypes.
    let tensors: HashMap<String, TypedTensor<_>> = HashMap::from([
        (String::from("a"), TypedTensor::F32(floats)),
        (String::from("b"), TypedTensor::I64(ints)),
    ]);
    safetensors::save(&tensors, &file).unwrap();

    let restored = safetensors::load_from_file(&file, &CPU).unwrap();
    assert_eq!(restored.len(), 2);

    let TypedTensor::F32(restored_a) = restored.get("a").unwrap() else {
        panic!("expected F32");
    };
    assert_eq!(restored_a.dims(), &[1, 3]);
    assert_eq!(restored_a.to_vec().unwrap(), vec![1.0, 2.0, 3.0]);

    let TypedTensor::I64(restored_b) = restored.get("b").unwrap() else {
        panic!("expected I64");
    };
    assert_eq!(restored_b.dims(), &[2]);
    assert_eq!(restored_b.to_vec().unwrap(), vec![10, 20]);

    // Reached only when every assertion above has passed.
    std::fs::remove_file(&file).unwrap();
}

/// Checks that `safetensors::save` accepts a map keyed by `&str` (rather than
/// `String`): saves one `f32` tensor under `"x"` and verifies its values after
/// reloading the file.
#[test]
fn save_with_str_keys() {
    let file = std::env::temp_dir().join("test_str_keys.safetensors");
    let source: CpuTensor<f32> = Tensor::from_vec(vec![5.0, 6.0], (2,), &CPU).unwrap();

    // Borrowed string keys are the point of this test.
    let mut by_name: HashMap<&str, TypedTensor<_>> = HashMap::new();
    by_name.insert("x", TypedTensor::F32(source));
    safetensors::save(&by_name, &file).unwrap();

    let restored = safetensors::load_from_file(&file, &CPU).unwrap();
    let TypedTensor::F32(restored_x) = restored.get("x").unwrap() else {
        panic!("expected F32");
    };
    assert_eq!(restored_x.to_vec().unwrap(), vec![5.0, 6.0]);

    // Reached only when every assertion above has passed.
    std::fs::remove_file(&file).unwrap();
}

/// Saves one-element `f16`, `bf16` and `u8` tensors into a single safetensors
/// file and checks that each is reloaded with the dtype it was saved with and
/// with an unchanged value.
#[test]
fn roundtrip_preserves_dtypes() {
    let path = std::env::temp_dir().join("test_dtypes.safetensors");

    let f16_t: CpuTensor<half::f16> =
        Tensor::from_vec(vec![half::f16::from_f32(1.0)], (1,), &CPU).unwrap();
    let bf16_t: CpuTensor<half::bf16> =
        Tensor::from_vec(vec![half::bf16::from_f32(2.0)], (1,), &CPU).unwrap();
    let u8_t: CpuTensor<u8> = Tensor::from_vec(vec![42], (1,), &CPU).unwrap();

    let mut map: HashMap<String, TypedTensor<_>> = HashMap::new();
    map.insert("f16".to_string(), TypedTensor::F16(f16_t));
    map.insert("bf16".to_string(), TypedTensor::BF16(bf16_t));
    map.insert("u8".to_string(), TypedTensor::U8(u8_t));
    safetensors::save(&map, &path).unwrap();

    let loaded = safetensors::load_from_file(&path, &CPU).unwrap();
    // Exactly the three saved entries should come back.
    assert_eq!(loaded.len(), 3);

    assert_eq!(loaded.get("f16").unwrap().dtype(), DType::F16);
    assert_eq!(loaded.get("bf16").unwrap().dtype(), DType::BF16);
    assert_eq!(loaded.get("u8").unwrap().dtype(), DType::U8);

    let lf16 = match loaded.get("f16").unwrap() {
        TypedTensor::F16(t) => t,
        _ => panic!("expected F16"),
    };
    assert_eq!(lf16.to_vec().unwrap(), vec![half::f16::from_f32(1.0)]);

    // Check the bf16 payload as well as its dtype tag, so a corrupted value
    // cannot slip through.
    let lbf16 = match loaded.get("bf16").unwrap() {
        TypedTensor::BF16(t) => t,
        _ => panic!("expected BF16"),
    };
    assert_eq!(lbf16.to_vec().unwrap(), vec![half::bf16::from_f32(2.0)]);

    let lu8 = match loaded.get("u8").unwrap() {
        TypedTensor::U8(t) => t,
        _ => panic!("expected U8"),
    };
    assert_eq!(lu8.to_vec().unwrap(), vec![42]);

    std::fs::remove_file(&path).unwrap();
}

/// Saves a 2x3 `f32` tensor to a safetensors file, reads the file into memory,
/// and checks that `safetensors::load_from_buffer` reconstructs the tensor
/// from those bytes with the same dims and values.
#[test]
fn save_load_buffer_roundtrip() {
    let path = std::env::temp_dir().join("test_buffer.safetensors");
    let t: CpuTensor<f32> =
        Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], (2, 3), &CPU).unwrap();
    t.save_safetensors("t", &path).unwrap();

    let bytes = std::fs::read(&path).unwrap();
    // The whole file is now in `bytes`, so delete it before asserting: a
    // failing assertion below then cannot leave the temp file behind.
    std::fs::remove_file(&path).unwrap();

    let loaded = safetensors::load_from_buffer(&bytes, &CPU).unwrap();
    let lt = match loaded.get("t").unwrap() {
        TypedTensor::F32(t) => t,
        _ => panic!("expected F32"),
    };
    assert_eq!(lt.dims(), &[2, 3]);
    assert_eq!(lt.to_vec().unwrap(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
}