import io
import pickle
import struct
import tarfile
import os
from collections import OrderedDict
def create_sys_info():
    """Build the legacy PyTorch 'sys_info' record, pickled at protocol 2.

    The record describes the writer's platform: legacy protocol version,
    endianness, and C integer type sizes.
    """
    info = {
        "protocol_version": 1000,
        "little_endian": True,
        "type_sizes": {"short": 2, "int": 4, "long": 8},
    }
    return pickle.dumps(info, protocol=2)
def encode_tensor_data(values: list, storage_type: str) -> tuple:
    """Pack a flat list of Python numbers as little-endian binary data.

    Returns ``(packed_bytes, element_size)`` for the given legacy storage
    type name. Raises ``KeyError`` for an unrecognized storage type.
    """
    format_for = {
        "FloatStorage": ("<f", 4),
        "DoubleStorage": ("<d", 8),
        "LongStorage": ("<q", 8),
        "IntStorage": ("<i", 4),
        "ShortStorage": ("<h", 2),
        "ByteStorage": ("<B", 1),
        "CharStorage": ("<b", 1),
        "BoolStorage": ("<B", 1),
        "HalfStorage": ("<e", 2),
    }
    fmt, element_size = format_for[storage_type]
    packer = struct.Struct(fmt)  # pre-compile the format once for the loop
    encoded = b"".join(packer.pack(value) for value in values)
    return encoded, element_size
def write_int(buffer, value):
    """Emit the smallest pickle integer opcode for *value* into *buffer*.

    K = BININT1 (one unsigned byte), M = BININT2 (2-byte unsigned LE),
    J = BININT (4-byte signed LE).
    """
    if 0 <= value < 0x100:
        buffer.write(b'K' + struct.pack('<B', value))
    elif 0 <= value < 0x10000:
        buffer.write(b'M' + struct.pack('<H', value))
    else:
        buffer.write(b'J' + struct.pack('<i', value))
def write_string(buffer, s):
    """Emit a pickle string opcode for *s* (UTF-8 encoded) into *buffer*.

    U = SHORT_BINSTRING (length < 256, one-byte length prefix),
    T = BINSTRING (4-byte LE length prefix).
    """
    encoded = s.encode('utf-8')
    length = len(encoded)
    if length < 0x100:
        buffer.write(b'U' + struct.pack('<B', length))
    else:
        buffer.write(b'T' + struct.pack('<I', length))
    buffer.write(encoded)
def create_storages_blob_manual(tensors: list) -> bytes:
    """Serialize the legacy 'storages' tar member.

    Layout: a pickled storage count, then for each storage a hand-written
    pickle record ``(key, 'cpu', torch.<StorageType>)`` followed by the
    element count as an 8-byte LE integer and the raw element bytes.

    tensors: iterable of ``(key, storage_type, element_size, data_bytes)``.
    """
    out = io.BytesIO()
    # Number of storages, written as an ordinary pickle.
    pickle.dump(len(tensors), out, protocol=2)
    for key, storage_type, element_size, data_bytes in tensors:
        record = io.BytesIO()
        record.write(b'\x80\x02')  # PROTO 2
        record.write(b'(')         # MARK
        write_string(record, key)
        record.write(b'U\x03cpu')  # SHORT_BINSTRING 'cpu'
        # GLOBAL torch.<storage_type>
        record.write(b'c' + b'torch\n' + storage_type.encode('ascii') + b'\n')
        record.write(b't' + b'.')  # TUPLE, STOP
        out.write(record.getvalue())
        out.write(struct.pack("<Q", len(data_bytes) // element_size))
        out.write(data_bytes)
    return out.getvalue()
def create_main_pickle_manual(tensors_info: list) -> bytes:
    """Hand-write the main 'pickle' tar member: an OrderedDict of tensors.

    Emits raw pickle (protocol 2) opcodes that reconstruct
    ``OrderedDict([(name, torch._utils._rebuild_tensor_v2(...)), ...])``.

    tensors_info: iterable of
        ``(name, storage_key, storage_type, shape, num_elements)``.
    """
    buffer = io.BytesIO()
    buffer.write(b'\x80\x02')  # PROTO 2
    buffer.write(b'ccollections\nOrderedDict\n')  # GLOBAL collections.OrderedDict
    buffer.write(b'(')  # MARK: start of OrderedDict's args
    buffer.write(b']')  # EMPTY_LIST: will collect (name, tensor) pairs
    for name, storage_key, storage_type, shape, num_elements in tensors_info:
        # Contiguous (row-major) strides: stride[i] = product of shape[i+1:].
        stride = []
        s = 1
        for dim in reversed(shape):
            stride.insert(0, s)
            s *= dim
        buffer.write(b'(')  # MARK: start of the (name, tensor) pair
        write_string(buffer, name)
        buffer.write(b'ctorch._utils\n_rebuild_tensor_v2\n')  # GLOBAL rebuild fn
        buffer.write(b'(')  # MARK: args for _rebuild_tensor_v2
        buffer.write(b'(')  # MARK: storage descriptor tuple
        write_string(buffer, 'storage')
        buffer.write(b'c')
        buffer.write(b'torch\n')
        buffer.write(storage_type.encode('ascii') + b'\n')  # GLOBAL torch.<Type>
        write_string(buffer, storage_key)
        buffer.write(b'U\x03cpu')  # device string 'cpu'
        write_int(buffer, num_elements)
        buffer.write(b't')  # TUPLE: close storage descriptor
        buffer.write(b'K\x00')  # BININT1 0: storage_offset
        buffer.write(b'(')  # MARK: shape tuple
        for dim in shape:
            write_int(buffer, dim)
        buffer.write(b't')  # TUPLE: close shape
        buffer.write(b'(')  # MARK: stride tuple
        for s_val in stride:
            write_int(buffer, s_val)
        buffer.write(b't')  # TUPLE: close stride
        buffer.write(b'\x89')  # NEWFALSE: requires_grad
        buffer.write(b'ccollections\nOrderedDict\n')  # backward-hooks dict
        buffer.write(b'(')
        buffer.write(b']')
        buffer.write(b't')
        buffer.write(b'R')  # REDUCE: empty OrderedDict (hooks)
        buffer.write(b't')  # TUPLE: close _rebuild_tensor_v2 args
        buffer.write(b'R')  # REDUCE: build the tensor
        buffer.write(b't')  # TUPLE: (name, tensor)
        buffer.write(b'a')  # APPEND pair to the list
    buffer.write(b't')  # TUPLE: close OrderedDict args
    buffer.write(b'R')  # REDUCE: OrderedDict(pairs)
    buffer.write(b'.')  # STOP
    return buffer.getvalue()
def create_tar_pytorch_file(filename: str, tensors: dict, dtypes: dict):
    """Write a legacy (pre-1.6, tar-based) PyTorch checkpoint file.

    filename: output .tar path (parent directory is created if needed).
    tensors: mapping of tensor name -> (flat value list, shape list).
    dtypes: mapping of tensor name -> legacy storage type name
        (e.g. "FloatStorage").

    The archive contains three members mimicking torch's old tar layout:
    'sys_info', 'pickle' (the tensor index), and 'storages' (raw data).
    """
    storage_list = []
    tensors_info = []
    for idx, (name, (values, shape)) in enumerate(tensors.items()):
        storage_key = str(idx)
        storage_type = dtypes[name]
        data_bytes, element_size = encode_tensor_data(values, storage_type)
        num_elements = len(values)
        storage_list.append((storage_key, storage_type, element_size, data_bytes))
        tensors_info.append((name, storage_key, storage_type, shape, num_elements))
    sys_info_data = create_sys_info()
    pickle_data = create_main_pickle_manual(tensors_info)
    storages_data = create_storages_blob_manual(storage_list)
    os.makedirs(os.path.dirname(filename) or ".", exist_ok=True)
    with tarfile.open(filename, "w") as tar:
        # Each member needs an explicit size set before addfile().
        for member_name, payload in (
            ("sys_info", sys_info_data),
            ("pickle", pickle_data),
            ("storages", storages_data),
        ):
            tarinfo = tarfile.TarInfo(name=member_name)
            tarinfo.size = len(payload)
            tar.addfile(tarinfo, io.BytesIO(payload))
    size = os.path.getsize(filename)
    # Fix: message previously printed the literal "(unknown)" instead of the path.
    print(f"Created {filename} ({size} bytes)")
    print(f" Tensors: {list(tensors.keys())}")
def main():
    """Generate all legacy tar-format PyTorch test fixtures under test_data/."""
    os.makedirs("test_data", exist_ok=True)
    # (path, tensors, dtypes) specs; dict literal order is preserved on write.
    cases = [
        ("test_data/tar_float32.tar",
         {"tensor": ([1.0, 2.5, -3.7, 0.0], [4])},
         {"tensor": "FloatStorage"}),
        ("test_data/tar_float64.tar",
         {"tensor": ([1.1, 2.2, 3.3], [3])},
         {"tensor": "DoubleStorage"}),
        ("test_data/tar_int64.tar",
         {"tensor": ([100, -200, 300, 0], [4])},
         {"tensor": "LongStorage"}),
        ("test_data/tar_weight_bias.tar",
         {"weight": ([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], [2, 3]),
          "bias": ([0.01, 0.02], [2])},
         {"weight": "FloatStorage",
          "bias": "FloatStorage"}),
        ("test_data/tar_multi_dtype.tar",
         {"float_tensor": ([1.5, 2.5, 3.5], [3]),
          "double_tensor": ([1.111, 2.222], [2]),
          "int_tensor": ([10, 20, 30, 40], [4])},
         {"float_tensor": "FloatStorage",
          "double_tensor": "DoubleStorage",
          "int_tensor": "LongStorage"}),
        ("test_data/tar_2d_tensor.tar",
         {"matrix": ([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0], [3, 4])},
         {"matrix": "FloatStorage"}),
    ]
    for path, tensors, dtypes in cases:
        create_tar_pytorch_file(path, tensors, dtypes)
    print("\nAll TAR format test files created!")
    print("\nTo run tests: cargo test -p burn-store --features pytorch test_tar")


if __name__ == "__main__":
    main()