//! boostr 0.1.0: ML framework built on numr (attention, quantization, model architectures).
//!
//! Tests for the safetensors loader and saver.

use super::*;
use crate::test_utils::cpu_setup;
use numr::runtime::cpu::CpuRuntime;
use std::collections::HashMap;
use std::io::Write;
use tempfile::NamedTempFile;

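/// Builds a minimal safetensors file by hand: an 8-byte little-endian u64
/// header length, a JSON header describing a single F32 tensor "weight" of
/// shape [2, 3], then the raw 24-byte payload at data_offsets [0, 24].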
fn create_test_file() -> NamedTempFile {
    let mut file = NamedTempFile::new().unwrap();

    let header = serde_json::json!({
        "__metadata__": { "format": "pt" },
        "weight": {
            "dtype": "F32",
            "shape": [2, 3],
            "data_offsets": [0, 24]
        }
    });
    let header_str = header.to_string();
    let header_bytes = header_str.as_bytes();

    file.write_all(&(header_bytes.len() as u64).to_le_bytes())
        .unwrap();
    file.write_all(header_bytes).unwrap();

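    // Payload: six f32 values, little-endian, row-major (6 * 4 = 24 bytes).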
    for f in [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0] {
        file.write_all(&f.to_le_bytes()).unwrap();
    }
    file.flush().unwrap();
    file
}

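// `len()` counts tensors only; the `__metadata__` entry is surfaced separately
// through `metadata()`, as the two assertions below demonstrate.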
#[test]
fn test_open_and_metadata() {
    let f = create_test_file();
    let st = SafeTensors::open(f.path()).unwrap();
    assert_eq!(st.len(), 1);
    assert_eq!(st.metadata().get("format"), Some(&"pt".to_string()));
}

#[test]
fn test_tensor_info() {
    let f = create_test_file();
    let st = SafeTensors::open(f.path()).unwrap();
    let info = st.tensor_info("weight").unwrap();
    assert_eq!(info.dtype, DType::F32);
    assert_eq!(info.shape, vec![2, 3]);
    assert_eq!(info.numel(), 6);
    assert_eq!(info.size_bytes(), 24);
}

#[test]
fn test_load_tensor_f32() {
    let (_, device) = cpu_setup();
    let f = create_test_file();
    let mut st = SafeTensors::open(f.path()).unwrap();
    let tensor = st.load_tensor::<CpuRuntime>("weight", &device).unwrap();
    assert_eq!(tensor.shape(), &[2, 3]);
    let data = tensor.to_vec::<f32>();
    assert!((data[0] - 1.0).abs() < 1e-6);
    assert!((data[5] - 6.0).abs() < 1e-6);
}

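/// Same layout as `create_test_file`, but the payload is BF16: 2 bytes per
/// element, so the six values span data_offsets [0, 12].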
fn create_test_file_bf16() -> NamedTempFile {
    let mut file = NamedTempFile::new().unwrap();

    let header = serde_json::json!({
        "__metadata__": { "format": "pt" },
        "weight": {
            "dtype": "BF16",
            "shape": [2, 3],
            "data_offsets": [0, 12]
        }
    });
    let header_str = header.to_string();
    let header_bytes = header_str.as_bytes();

    file.write_all(&(header_bytes.len() as u64).to_le_bytes())
        .unwrap();
    file.write_all(header_bytes).unwrap();

    for f in [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0] {
        file.write_all(&half::bf16::from_f32(f).to_le_bytes())
            .unwrap();
    }
    file.flush().unwrap();
    file
}

#[test]
fn test_load_tensor_bf16() {
    let (_, device) = cpu_setup();
    let f = create_test_file_bf16();
    let mut st = SafeTensors::open(f.path()).unwrap();
    let tensor = st.load_tensor::<CpuRuntime>("weight", &device).unwrap();
    assert_eq!(tensor.shape(), &[2, 3]);
    assert_eq!(tensor.dtype(), DType::BF16);
    let data = tensor.to_vec::<half::bf16>();
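    // bf16 keeps only 7 mantissa bits, so compare with a loose 1e-2 tolerance
    // (the small integers here are exactly representable, but stay defensive).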
    assert!((data[0].to_f32() - 1.0).abs() < 1e-2);
    assert!((data[5].to_f32() - 6.0).abs() < 1e-2);
}

#[test]
fn test_tensor_not_found() {
    let f = create_test_file();
    let st = SafeTensors::open(f.path()).unwrap();
    assert!(st.tensor_info("nonexistent").is_err());
}

#[test]
fn test_save_and_load_roundtrip() {
    let (_, device) = cpu_setup();
    let tmp = NamedTempFile::new().unwrap();

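    // Save a single 2x2 F32 tensor, then reopen the file and check that the
    // values survive a full serialize/deserialize round trip.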
    let mut tensors = HashMap::new();
    tensors.insert(
        "w1".to_string(),
        Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2], &device),
    );

    save_safetensors(tmp.path(), &tensors, None).unwrap();

    let mut loaded = SafeTensors::open(tmp.path()).unwrap();
    assert_eq!(loaded.len(), 1);
    let t = loaded.load_tensor::<CpuRuntime>("w1", &device).unwrap();
    assert_eq!(t.shape(), &[2, 2]);
    let data = t.to_vec::<f32>();
    assert!((data[0] - 1.0).abs() < 1e-6);
    assert!((data[3] - 4.0).abs() < 1e-6);
}