#![cfg(feature = "all-formats")]
use std::collections::BTreeMap;
use std::io::{Cursor, Seek, SeekFrom};
use half::f16;
use ztensor::{DType, PyTorchReader, Reader, SafeTensorsReader, TensorReader, Writer};
mod common;
use common::data_generators::*;
use common::pytorch_builder::*;
use common::safetensors_builder::*;
/// Round-trips `data` as an f32 tensor named `name` through all three
/// container formats (ZTensor in-memory, SafeTensors temp file, PyTorch
/// zip temp file) and asserts each reader returns the original values.
///
/// `shape_zt` / `shape_st` are the same logical shape in the element
/// types the respective builders expect; the PyTorch copy is stored as
/// a flat, contiguous 1-D tensor of `data.len()` elements.
fn cross_format_verify_f32(name: &str, shape_zt: Vec<u64>, shape_st: Vec<usize>, data: &[f32]) {
    let raw_bytes: Vec<u8> = bytemuck::cast_slice(data).to_vec();
    let n: usize = data.len();
    // ZTensor: write into an in-memory cursor, rewind, read back.
    let mut zt_buf = Cursor::new(Vec::new());
    let mut w = Writer::new(&mut zt_buf).unwrap();
    w.add(name, &shape_zt, data).unwrap();
    w.finish().unwrap();
    zt_buf.seek(SeekFrom::Start(0)).unwrap();
    let zt_reader = Reader::new(&mut zt_buf).unwrap();
    let zt_data: Vec<f32> = zt_reader.read_as(name).unwrap();
    // SafeTensors: same raw payload written to a temp file.
    let st_file = build_safetensors_file(vec![(
        name.into(),
        safetensors::Dtype::F32,
        shape_st,
        raw_bytes.clone(),
    )]);
    let st_reader = SafeTensorsReader::open(st_file.path()).unwrap();
    let st_data: Vec<f32> = st_reader.read_as(name).unwrap();
    // PyTorch: store the tensor flat (1-D, contiguous). The previous
    // `min/max/checked_div/map` chain here was a convoluted no-op that
    // always produced `vec![n]` (`checked_div(1)` is always `Some` and
    // the closure ignored its input) — state the intent directly.
    let shape_pt = vec![n];
    let stride_pt = compute_strides(&shape_pt);
    let specs = vec![PtTensorSpec {
        name: name.into(),
        storage_type: "FloatStorage".into(),
        storage_key: "0".into(),
        shape: shape_pt,
        stride: stride_pt,
        storage_offset: 0,
        numel: n,
    }];
    let mut storage = BTreeMap::new();
    storage.insert("0".into(), raw_bytes);
    let pt_file = build_pytorch_zip(&specs, &storage);
    let pt_reader = PyTorchReader::open(pt_file.path()).unwrap();
    let pt_data: Vec<f32> = pt_reader.read_as(name).unwrap();
    assert_eq!(zt_data, data, "ZTensor data mismatch");
    assert_eq!(st_data, data, "SafeTensors data mismatch");
    assert_eq!(pt_data, data, "PyTorch data mismatch");
}
#[test]
fn cross_f32_2d() {
    // A 64x128 matrix round-tripped through every supported container.
    let values = make_f32_data(64 * 128);
    cross_format_verify_f32("matrix", vec![64, 128], vec![64, 128], &values);
}
#[test]
fn cross_f16_all_formats() {
    // Round-trip 256 f16 values through ZTensor, SafeTensors, and PyTorch.
    let expected = make_f16_data(256);
    let payload: Vec<u8> = bytemuck::cast_slice(&expected).to_vec();
    // ZTensor: in-memory write, rewind, read back.
    let mut buf = Cursor::new(Vec::new());
    let mut writer = Writer::new(&mut buf).unwrap();
    writer.add("t", &[256], &expected).unwrap();
    writer.finish().unwrap();
    buf.seek(SeekFrom::Start(0)).unwrap();
    let from_zt: Vec<f16> = Reader::new(&mut buf).unwrap().read_as("t").unwrap();
    // SafeTensors: same payload via a temp file.
    let st_file = build_safetensors_file(vec![(
        "t".into(),
        safetensors::Dtype::F16,
        vec![256],
        payload.clone(),
    )]);
    let from_st: Vec<f16> = SafeTensorsReader::open(st_file.path())
        .unwrap()
        .read_as("t")
        .unwrap();
    // PyTorch: a single HalfStorage-backed 1-D tensor.
    let specs = vec![PtTensorSpec {
        name: "t".into(),
        storage_type: "HalfStorage".into(),
        storage_key: "0".into(),
        shape: vec![256],
        stride: vec![1],
        storage_offset: 0,
        numel: 256,
    }];
    let mut blobs = BTreeMap::new();
    blobs.insert("0".into(), payload);
    let pt_file = build_pytorch_zip(&specs, &blobs);
    let from_pt: Vec<f16> = PyTorchReader::open(pt_file.path())
        .unwrap()
        .read_as("t")
        .unwrap();
    assert_eq!(from_zt, expected);
    assert_eq!(from_st, expected);
    assert_eq!(from_pt, expected);
}
#[test]
fn cross_i32_all_formats() {
    // Round-trip 128 i32 values through ZTensor, SafeTensors, and PyTorch.
    let expected = make_i32_data(128);
    let payload: Vec<u8> = bytemuck::cast_slice(&expected).to_vec();
    // ZTensor: in-memory write, rewind, read back.
    let mut buf = Cursor::new(Vec::new());
    let mut writer = Writer::new(&mut buf).unwrap();
    writer.add("t", &[128], &expected).unwrap();
    writer.finish().unwrap();
    buf.seek(SeekFrom::Start(0)).unwrap();
    let from_zt: Vec<i32> = Reader::new(&mut buf).unwrap().read_as("t").unwrap();
    // SafeTensors: same payload via a temp file.
    let st_file = build_safetensors_file(vec![(
        "t".into(),
        safetensors::Dtype::I32,
        vec![128],
        payload.clone(),
    )]);
    let from_st: Vec<i32> = SafeTensorsReader::open(st_file.path())
        .unwrap()
        .read_as("t")
        .unwrap();
    // PyTorch: a single IntStorage-backed 1-D tensor.
    let specs = vec![PtTensorSpec {
        name: "t".into(),
        storage_type: "IntStorage".into(),
        storage_key: "0".into(),
        shape: vec![128],
        stride: vec![1],
        storage_offset: 0,
        numel: 128,
    }];
    let mut blobs = BTreeMap::new();
    blobs.insert("0".into(), payload);
    let pt_file = build_pytorch_zip(&specs, &blobs);
    let from_pt: Vec<i32> = PyTorchReader::open(pt_file.path())
        .unwrap()
        .read_as("t")
        .unwrap();
    assert_eq!(from_zt, expected);
    assert_eq!(from_st, expected);
    assert_eq!(from_pt, expected);
}
#[test]
fn cross_large_tensor() {
    // A larger 512x512 tensor, exercising multi-chunk / multi-page paths.
    let values = make_f32_data(512 * 512);
    cross_format_verify_f32("big", vec![512, 512], vec![512, 512], &values);
}
#[test]
fn cross_manifest_shapes() {
    // All three readers must report identical shape and dtype metadata
    // for the same 8x16 f32 tensor (no data read-back here).
    let values = make_f32_data(8 * 16);
    let payload: Vec<u8> = bytemuck::cast_slice(&values).to_vec();
    // ZTensor container.
    let mut buf = Cursor::new(Vec::new());
    let mut writer = Writer::new(&mut buf).unwrap();
    writer.add("t", &[8, 16], &values).unwrap();
    writer.finish().unwrap();
    buf.seek(SeekFrom::Start(0)).unwrap();
    let zt_reader = Reader::new(&mut buf).unwrap();
    // SafeTensors container.
    let st_file = build_safetensors_file(vec![(
        "t".into(),
        safetensors::Dtype::F32,
        vec![8, 16],
        payload.clone(),
    )]);
    let st_reader = SafeTensorsReader::open(st_file.path()).unwrap();
    // PyTorch container.
    let specs = vec![PtTensorSpec {
        name: "t".into(),
        storage_type: "FloatStorage".into(),
        storage_key: "0".into(),
        shape: vec![8, 16],
        stride: vec![16, 1],
        storage_offset: 0,
        numel: 128,
    }];
    let mut blobs = BTreeMap::new();
    blobs.insert("0".into(), payload);
    let pt_file = build_pytorch_zip(&specs, &blobs);
    let pt_reader = PyTorchReader::open(pt_file.path()).unwrap();
    // Compare manifest entries.
    let zt_meta = zt_reader.get("t").unwrap();
    let st_meta = st_reader.get("t").unwrap();
    let pt_meta = pt_reader.get("t").unwrap();
    assert_eq!(zt_meta.shape, vec![8, 16]);
    assert_eq!(st_meta.shape, vec![8, 16]);
    assert_eq!(pt_meta.shape, vec![8, 16]);
    assert_eq!(zt_meta.components.get("data").unwrap().dtype, DType::F32);
    assert_eq!(st_meta.components.get("data").unwrap().dtype, DType::F32);
    assert_eq!(pt_meta.components.get("data").unwrap().dtype, DType::F32);
}
#[test]
fn cross_model_checkpoint() {
    // Simulate a 10-layer checkpoint: each layer has a weight matrix and a
    // bias vector, written to all three formats and verified per layer.
    for layer_idx in 0..10 {
        let weight_name = format!("layer.{}.weight", layer_idx);
        let bias_name = format!("layer.{}.bias", layer_idx);
        let weights = make_f32_data(32 * 32);
        let biases = make_f32_data(32);
        let weight_bytes: Vec<u8> = bytemuck::cast_slice(&weights).to_vec();
        let bias_bytes: Vec<u8> = bytemuck::cast_slice(&biases).to_vec();
        // ZTensor container with both tensors.
        let mut buf = Cursor::new(Vec::new());
        let mut writer = Writer::new(&mut buf).unwrap();
        writer.add(&weight_name, &[32, 32], &weights).unwrap();
        writer.add(&bias_name, &[32], &biases).unwrap();
        writer.finish().unwrap();
        buf.seek(SeekFrom::Start(0)).unwrap();
        let zt_reader = Reader::new(&mut buf).unwrap();
        // SafeTensors container with both tensors.
        let st_file = build_safetensors_file(vec![
            (
                weight_name.clone(),
                safetensors::Dtype::F32,
                vec![32, 32],
                weight_bytes.clone(),
            ),
            (
                bias_name.clone(),
                safetensors::Dtype::F32,
                vec![32],
                bias_bytes.clone(),
            ),
        ]);
        let st_reader = SafeTensorsReader::open(st_file.path()).unwrap();
        // PyTorch container: one storage entry per tensor.
        let specs = vec![
            PtTensorSpec {
                name: weight_name.clone(),
                storage_type: "FloatStorage".into(),
                storage_key: "0".into(),
                shape: vec![32, 32],
                stride: vec![32, 1],
                storage_offset: 0,
                numel: 1024,
            },
            PtTensorSpec {
                name: bias_name.clone(),
                storage_type: "FloatStorage".into(),
                storage_key: "1".into(),
                shape: vec![32],
                stride: vec![1],
                storage_offset: 0,
                numel: 32,
            },
        ];
        let mut blobs = BTreeMap::new();
        blobs.insert("0".into(), weight_bytes);
        blobs.insert("1".into(), bias_bytes);
        let pt_file = build_pytorch_zip(&specs, &blobs);
        let pt_reader = PyTorchReader::open(pt_file.path()).unwrap();
        // Weights must match in every format.
        let zt_w: Vec<f32> = zt_reader.read_as(&weight_name).unwrap();
        let st_w: Vec<f32> = st_reader.read_as(&weight_name).unwrap();
        let pt_w: Vec<f32> = pt_reader.read_as(&weight_name).unwrap();
        assert_eq!(zt_w, weights, "layer {} weight ZT mismatch", layer_idx);
        assert_eq!(st_w, weights, "layer {} weight ST mismatch", layer_idx);
        assert_eq!(pt_w, weights, "layer {} weight PT mismatch", layer_idx);
        // Biases must match in every format.
        let zt_b: Vec<f32> = zt_reader.read_as(&bias_name).unwrap();
        let st_b: Vec<f32> = st_reader.read_as(&bias_name).unwrap();
        let pt_b: Vec<f32> = pt_reader.read_as(&bias_name).unwrap();
        assert_eq!(zt_b, biases, "layer {} bias ZT mismatch", layer_idx);
        assert_eq!(st_b, biases, "layer {} bias ST mismatch", layer_idx);
        assert_eq!(pt_b, biases, "layer {} bias PT mismatch", layer_idx);
    }
}
#[test]
fn cross_mixed_dtypes() {
    // Three tensors of different dtypes (f32, i64, u8) in one container of
    // each format; every tensor must survive the round trip in all three.
    let f32_values = make_f32_data(64);
    let i64_values = make_i64_data(32);
    let u8_values = make_u8_data(128);
    let f32_payload: Vec<u8> = bytemuck::cast_slice(&f32_values).to_vec();
    let i64_payload: Vec<u8> = bytemuck::cast_slice(&i64_values).to_vec();
    // ZTensor container.
    let mut buf = Cursor::new(Vec::new());
    let mut writer = Writer::new(&mut buf).unwrap();
    writer.add("f32", &[64], &f32_values).unwrap();
    writer.add("i64", &[32], &i64_values).unwrap();
    writer.add("u8", &[128], &u8_values).unwrap();
    writer.finish().unwrap();
    buf.seek(SeekFrom::Start(0)).unwrap();
    let zt_reader = Reader::new(&mut buf).unwrap();
    // SafeTensors container.
    let st_file = build_safetensors_file(vec![
        (
            "f32".into(),
            safetensors::Dtype::F32,
            vec![64],
            f32_payload.clone(),
        ),
        (
            "i64".into(),
            safetensors::Dtype::I64,
            vec![32],
            i64_payload.clone(),
        ),
        (
            "u8".into(),
            safetensors::Dtype::U8,
            vec![128],
            u8_values.clone(),
        ),
    ]);
    let st_reader = SafeTensorsReader::open(st_file.path()).unwrap();
    // PyTorch container: one storage entry per dtype.
    let specs = vec![
        PtTensorSpec {
            name: "f32".into(),
            storage_type: "FloatStorage".into(),
            storage_key: "0".into(),
            shape: vec![64],
            stride: vec![1],
            storage_offset: 0,
            numel: 64,
        },
        PtTensorSpec {
            name: "i64".into(),
            storage_type: "LongStorage".into(),
            storage_key: "1".into(),
            shape: vec![32],
            stride: vec![1],
            storage_offset: 0,
            numel: 32,
        },
        PtTensorSpec {
            name: "u8".into(),
            storage_type: "ByteStorage".into(),
            storage_key: "2".into(),
            shape: vec![128],
            stride: vec![1],
            storage_offset: 0,
            numel: 128,
        },
    ];
    let mut blobs = BTreeMap::new();
    blobs.insert("0".into(), f32_payload);
    blobs.insert("1".into(), i64_payload);
    blobs.insert("2".into(), u8_values.clone());
    let pt_file = build_pytorch_zip(&specs, &blobs);
    let pt_reader = PyTorchReader::open(pt_file.path()).unwrap();
    // Every dtype must round-trip intact through every format.
    assert_eq!(zt_reader.read_as::<f32>("f32").unwrap(), f32_values);
    assert_eq!(st_reader.read_as::<f32>("f32").unwrap(), f32_values);
    assert_eq!(pt_reader.read_as::<f32>("f32").unwrap(), f32_values);
    assert_eq!(zt_reader.read_as::<i64>("i64").unwrap(), i64_values);
    assert_eq!(st_reader.read_as::<i64>("i64").unwrap(), i64_values);
    assert_eq!(pt_reader.read_as::<i64>("i64").unwrap(), i64_values);
    assert_eq!(zt_reader.read_as::<u8>("u8").unwrap(), u8_values);
    assert_eq!(st_reader.read_as::<u8>("u8").unwrap(), u8_values);
    assert_eq!(pt_reader.read_as::<u8>("u8").unwrap(), u8_values);
}