import argparse
import glob
import logging
import subprocess
import sys
import zipfile
from pathlib import Path
from typing import Dict
import os
import numpy as np
import torch
from numpy.lib.format import write_array
from torch import Tensor
def zipfile_factory(file, *args, **kwargs):
    """Open a ``zipfile.ZipFile`` with Zip64 extensions forced on.

    ``file`` may be a path-like object or an already-open binary stream:
    anything with a ``read`` attribute is passed through untouched,
    otherwise it is normalised with ``os.fspath``. ``allowZip64`` is always
    enabled so archives larger than 4GB can be written. ``compresslevel``
    is pinned to 4; note it only takes effect when the caller selects a
    real compression method (it is ignored for ZIP_STORED).

    Returns:
        An open ``zipfile.ZipFile`` constructed with the given args/kwargs.
    """
    if not hasattr(file, 'read'):
        file = os.fspath(file)
    # The redundant function-local `import zipfile` was removed: the module
    # is already imported at the top of the file.
    kwargs['allowZip64'] = True
    kwargs['compresslevel'] = 4
    return zipfile.ZipFile(file, *args, **kwargs)
def get_bf16_repr(input_tensor: torch.Tensor) -> np.ndarray:
    """Return the bfloat16 bit patterns of a tensor as a flat uint16 array.

    The input is upcast to float32 on CPU, then each 32-bit word is reduced
    to its top 16 bits using round-to-nearest-even, mirroring the Rust
    ``half`` crate's ``f32 -> bf16`` conversion. NaN inputs keep their high
    mantissa bits and have the quiet bit (0x0040) forced on so they remain
    NaN after truncation.

    Args:
        input_tensor: tensor of any shape and floating dtype.

    Returns:
        1-D ``np.ndarray`` of dtype ``uint16`` with one entry per element.
    """
    v_fp32 = input_tensor.cpu().float().numpy()
    byte_array = np.frombuffer(v_fp32.tobytes(), dtype=np.uint32)
    # BUG FIX: the original used np.logical_or / np.logical_and, which are
    # BOOLEAN operators (returning True/False arrays). That made nan_mask
    # always False and rounded every nonzero value up. Bitwise operators
    # are required for this bit-manipulation algorithm.
    nan_value = np.bitwise_or(np.right_shift(byte_array, 16), 0x0040)
    nan_mask = np.bitwise_and(byte_array, 0x7FFF_FFFF) > 0x7F80_0000
    round_bit = 0x0000_8000  # guard bit: highest bit dropped by the shift
    output_val = np.right_shift(byte_array, 16)
    # Round up when the guard bit is set AND (a sticky bit OR the result's
    # LSB is set) — ties-to-even. The mask 3*round_bit - 1 == 0x17FFF covers
    # bits 0-14 (sticky) and bit 16 (result LSB) while excluding bit 15.
    threshold_mask = (np.bitwise_and(byte_array, round_bit) != 0) & (
        np.bitwise_and(byte_array, (3 * round_bit - 1)) != 0
    )
    output = np.where(
        nan_mask, nan_value, np.where(threshold_mask, output_val + 1, output_val)
    ).astype(np.uint16)
    return output
def append_to_zipf(
    array_dict: Dict[str, np.ndarray], parent_zipfile: zipfile.ZipFile
) -> None:
    """Serialize every array in ``array_dict`` into ``parent_zipfile``.

    Each entry is written as ``<key>.npy`` with numpy's native .npy format,
    so the resulting archive is readable via ``np.load``. Zip64 is forced
    per member to allow entries larger than 4GB.

    Args:
        array_dict: mapping of entry name to array(-like) value to store.
        parent_zipfile: an open, writable ``zipfile.ZipFile``.
    """
    for name, value in array_dict.items():
        entry_name = name + ".npy"
        payload = np.asanyarray(value)
        with parent_zipfile.open(entry_name, "w", force_zip64=True) as handle:
            write_array(handle, payload, allow_pickle=True, pickle_kwargs=None)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"source_file",
nargs="+",
help="""Absolute path (or file pattern) to the Pytorch weights file(s) to convert.
A single file, list of files, glob pattern or list of glob patterns can be provided.""",
)
parser.add_argument(
"--skip_embeddings",
action="store_true",
help="Skip shared embeddings",
)
parser.add_argument(
"--skip_lm_head", action="store_true", help="Skip language model head"
)
parser.add_argument("--prefix", help="Add a prefix on weight names")
parser.add_argument(
"--suffix",
action="store_true",
help="Split weight names on '.' and keep only last part",
)
parser.add_argument(
"--dtype",
help="Convert weights to a specific numpy DataType (float32, float16, ...)",
)
parser.add_argument(
"--download_libtorch",
action="store_true",
help="Use this flag to enable automatic download of the libtorch library.",
)
args = parser.parse_args()
logger = logging.getLogger('convert_model')
logger.setLevel(logging.DEBUG)
fh = logging.FileHandler('convert_model.log')
fh.setLevel(logging.DEBUG)
logger.addHandler(fh)
target_folder = Path(args.source_file[0]).parent
with zipfile_factory(
target_folder / "model.npz", mode="w", compression=False
) as output_zipfile:
for source_file_or_pattern in args.source_file:
source_files = glob.glob(source_file_or_pattern)
for source_file in source_files:
logger.info(f"Processing source file {source_file}")
nps = {}
source_file = Path(source_file)
weights = torch.load(str(source_file), map_location="cpu")
for k, v in weights.items():
k = k.replace("gamma", "weight").replace("beta", "bias")
if args.skip_embeddings:
if k in {
"model.encoder.embed_tokens.weight",
"encoder.embed_tokens.weight",
"model.decoder.embed_tokens.weight",
"decoder.embed_tokens.weight",
}:
continue
if args.skip_lm_head:
if k in {
"lm_head.weight",
}:
continue
if args.prefix:
k = args.prefix + k
if args.suffix:
k = k.split(".")[-1]
if isinstance(v, Tensor):
if v.dtype == torch.bfloat16:
tensor = get_bf16_repr(v)
else:
tensor = v.cpu().numpy()
if args.dtype is not None:
nps[k] = np.ascontiguousarray(
tensor.astype(np.dtype(args.dtype))
)
else:
nps[k] = np.ascontiguousarray(tensor)
logger.info(
f"converted {k} - {str(sys.getsizeof(nps[k]))} bytes"
)
else:
logger.info(f"skipped non-tensor object: {k}")
append_to_zipf(nps, output_zipfile)
source = str(target_folder / "model.npz")
target = str(target_folder / "rust_model.ot")
toml_location = (Path(__file__).resolve() / ".." / ".." / "Cargo.toml").resolve()
cargo_args = [
"cargo",
"run",
"--bin=convert-tensor",
"--manifest-path=%s" % toml_location,
"--",
source,
target,
]
if args.download_libtorch:
cargo_args += ["--features", "download-libtorch"]
subprocess.run(cargo_args)