import abc
import ctypes
import json
import os
import shutil
import struct
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
# torch is an optional dependency: it is only needed to wrap torch.Tensor
# weights (PyTorchVariable) and for bfloat16 conversion.
try:
    import torch
    torch_is_available = True
except ImportError:
    torch_is_available = False

# Sentinel string marking an attribute as an unset optional weight: such
# attributes pass validation and are skipped when collecting variables.
OPTIONAL = "__optional"
# Version number written at the head of the serialized model.bin file.
CURRENT_BINARY_VERSION = 6
# Quantization/conversion types accepted by LayerSpec._quantize.
ACCEPTED_MODEL_TYPES = (
    "int8",
    "int8_float32",
    "int8_float16",
    "int8_bfloat16",
    "int16",
    "float16",
    "bfloat16",
    "float32",
)
# Attribute names that must never be replaced by an alias to an equal
# variable (see LayerSpec._alias_variables).
SKIP_CREATING_ALIAS = ("rotary_scaling_long_factor", "rotary_scaling_short_factor")
def _join_scope(scope, name):
if not scope:
return name
return "%s/%s" % (scope, name)
def _split_scope(scope):
return scope.split("/")
def _parent_scope(scope):
keys = _split_scope(scope)
scope, attr = keys[:-1], keys[-1]
return "/".join(scope), attr
def visit_spec(spec, fn, scope=""):
    """Recursively walk the public attributes of *spec*.

    List attributes are visited element by element with an ``_<index>``
    suffix, nested ``LayerSpec`` attributes are recursed into, and every
    other value is passed to ``fn(spec, scoped_name, value)``.
    """
    for key, attr in list(spec.__dict__.items()):
        # Attributes starting with "_" are internal and never visited.
        if key.startswith("_"):
            continue
        child_scope = _join_scope(scope, key)
        if isinstance(attr, list):
            for index, item in enumerate(attr):
                visit_spec(item, fn, scope="%s_%d" % (child_scope, index))
        elif isinstance(attr, LayerSpec):
            visit_spec(attr, fn, scope=child_scope)
        else:
            fn(spec, child_scope, attr)
def index_spec(spec, index):
    """Resolve a "/"-separated attribute path against *spec*.

    A key such as ``layers_3`` that is not a plain attribute addresses
    element 3 of the list attribute ``layers``.
    """
    if not index:
        return spec
    for key in index.split("/"):
        try:
            spec = getattr(spec, key)
        except AttributeError:
            # "name_<i>" style keys index into a list attribute.
            name, position = key.rsplit("_", 1)
            spec = getattr(spec, name)[int(position)]
    return spec
class FrozenMeta(type):
    """Metaclass that marks instances as frozen once construction completes."""

    def __call__(cls, *args, **kwargs):
        obj = super().__call__(*args, **kwargs)
        # Setting _frozen activates the FrozenAttr.__setattr__ guard.
        obj._frozen = True
        return obj
class FrozenAttr:
    """Mixin that forbids creating new attributes after the instance is frozen."""

    def __setattr__(self, name, value):
        # Once _frozen exists (set by FrozenMeta after __init__), only
        # already-existing attributes may be assigned.
        if hasattr(self, "_frozen") and not hasattr(self, name):
            raise AttributeError("Attribute %s does not exist" % name)
        super().__setattr__(name, value)
class LayerSpec(FrozenAttr, metaclass=FrozenMeta):
    """Base class for layer specifications.

    Public attributes hold the layer weights (or nested specs / lists of
    specs). ``None`` means "required but not yet set"; the string OPTIONAL
    marks an optional weight that was intentionally left unset.
    """

    def validate(self) -> None:
        """Check that all required attributes are set and normalize values.

        Arrays, scalars, booleans, and non-OPTIONAL strings are converted in
        place into ``Variable`` wrappers; float64 arrays are downcast to
        float32.

        Raises:
          ValueError: if one or more attributes are still ``None``.
        """
        unset_attributes = []

        def _check(spec, name, value):
            if value is None:
                unset_attributes.append(name)
                return
            if isinstance(value, np.ndarray):
                # The runtime does not use double precision.
                if value.dtype == np.float64:
                    value = value.astype(np.float32)
            elif isinstance(value, float):
                value = np.dtype("float32").type(value)
            elif isinstance(value, bool):
                value = np.dtype("int8").type(value)
            elif isinstance(value, str):
                # Non-OPTIONAL strings are stored as raw UTF-8 bytes.
                if value != OPTIONAL:
                    value = np.frombuffer(value.encode("utf-8"), dtype=np.int8)
            if isinstance(value, (np.ndarray, np.generic)):
                value = NumpyVariable(value)
            elif torch_is_available and isinstance(value, torch.Tensor):
                value = PyTorchVariable(value)
            attr_name = _split_scope(name)[-1]
            setattr(spec, attr_name, value)

        self._visit(_check)

        if unset_attributes:
            raise ValueError(
                "Some required model attributes are not set:\n\n%s"
                % "\n".join(unset_attributes)
            )

    def variables(
        self,
        prefix: str = "",
        ordered: bool = False,
    ) -> Union[Dict[str, "Variable"], List[Tuple[str, "Variable"]]]:
        """Collect all variables of this spec, keyed by their scoped name.

        Args:
          prefix: optional scope prepended to every variable name.
          ordered: when True, return a list of ``(name, value)`` pairs sorted
            by name instead of a dict.

        Returns:
          A dict of variables, or a name-sorted list of pairs when *ordered*
          is True. Values may be strings when aliases were created.
        """
        var = {}

        def _register_var(spec, name, value):
            # OPTIONAL markers stand for unset optional weights: skip them.
            if isinstance(value, str) and value == OPTIONAL:
                return
            var[_join_scope(prefix, name)] = value

        self._visit(_register_var)
        if ordered:
            # sorted() already returns a list; no extra copy needed.
            return sorted(var.items(), key=lambda item: item[0])
        return var

    def _alias_variables(self):
        """Replace duplicated variables by a string alias to the first equal one."""
        variables = self.variables(ordered=True)
        for name, value in reversed(variables):
            # Search the variables that sort before this one for an equal value.
            for other_name, other_value in variables:
                if name == other_name:
                    break
                scope, attr_name = _parent_scope(name)
                if (
                    not value.is_scalar()
                    and value.equal(other_value)
                    and attr_name not in SKIP_CREATING_ALIAS
                ):
                    spec = index_spec(self, scope)
                    setattr(spec, attr_name, other_name)
                    break

    def _quantize(self, quantization):
        """Quantize or convert the weights in place.

        Args:
          quantization: one of ACCEPTED_MODEL_TYPES, or None to do nothing.

        Raises:
          ValueError: if *quantization* is not a supported type.
        """
        if quantization is not None and quantization not in ACCEPTED_MODEL_TYPES:
            raise ValueError(
                "%s is not a valid quantization type. Accepted types are: %s"
                % (quantization, ", ".join(ACCEPTED_MODEL_TYPES))
            )

        def _quantize(spec, name, value):
            if not isinstance(value, Variable) or value.is_scalar():
                return

            key = _split_scope(name)[-1]
            scale = None
            # A weight is quantizable when the spec declares a matching
            # "<name>_scale" attribute to store the scale into.
            is_quantizable = hasattr(spec, "%s_scale" % key)
            is_convertible = value.dtype in ("float32", "float16", "bfloat16")

            if is_quantizable:
                if quantization == "int16":
                    value = value.to("float32").numpy()
                    # Single global scale relative to the largest weight.
                    scale = np.float32(2**10 / np.amax(np.absolute(value)))
                    value *= scale
                    value = np.rint(value)
                    value = np.clip(
                        value, np.iinfo(np.int16).min, np.iinfo(np.int16).max
                    )
                    value = value.astype(np.int16)
                    scale = NumpyVariable(scale)
                    value = NumpyVariable(value)
                elif quantization in (
                    "int8",
                    "int8_float32",
                    "int8_float16",
                    "int8_bfloat16",
                ):
                    value = value.to("float32").numpy()
                    old_shape = None
                    if len(value.shape) == 3:
                        # Flatten the two inner dimensions so scales are
                        # computed per row, then restore the shape below.
                        old_shape = value.shape
                        value = value.reshape(value.shape[0], -1)
                    amax = np.amax(np.absolute(value), axis=1)
                    amax[amax == 0] = 127.0  # avoid dividing by zero
                    scale = 127.0 / amax
                    value *= np.expand_dims(scale, 1)
                    value = np.rint(value)
                    value = value.astype(np.int8)
                    if old_shape:
                        value = value.reshape(old_shape)
                    scale = NumpyVariable(scale)
                    value = NumpyVariable(value)
                elif quantization in ("float16", "bfloat16", "float32"):
                    value = value.to(quantization)
            elif is_convertible:
                # Non-quantizable float weights are converted to the float
                # type matching the requested model type.
                if quantization in ("float16", "int8_float16"):
                    value = value.to("float16")
                elif quantization in ("bfloat16", "int8_bfloat16"):
                    value = value.to("bfloat16")
                elif quantization in ("float32", "int16", "int8_float32"):
                    value = value.to("float32")

            setattr(spec, key, value)
            if scale is not None:
                setattr(spec, "%s_scale" % key, scale)

        self._visit(_quantize)

    def optimize(self, quantization: Optional[str] = None) -> None:
        """Alias duplicated weights, then quantize/convert the rest."""
        self._alias_variables()
        self._quantize(quantization)

    def _visit(self, fn):
        """Apply ``fn(spec, name, value)`` to every leaf attribute."""
        visit_spec(self, fn)
def _dtype_to_type_id(object_dtype):
dtypes = ("float32", "int8", "int16", "int32", "float16", "bfloat16")
try:
return dtypes.index(object_dtype)
except ValueError:
raise ValueError(
"%s is not in list of supported dtypes: %s"
% (object_dtype, ", ".join(dtypes))
)
class ModelConfig(FrozenAttr, metaclass=FrozenMeta):
    """Base class for model configurations saved alongside the model."""

    def __init__(self, **kwargs):
        for name, value in kwargs.items():
            setattr(self, name, value)

    def to_dict(self):
        """Return the public attributes as a plain dictionary."""
        result = {}
        for name, value in self.__dict__.items():
            if not name.startswith("_"):
                result[name] = value
        return result

    def add_attribute(self, key, value):
        # Write directly into __dict__ to bypass the FrozenAttr guard,
        # allowing new keys after the instance was frozen.
        self.__dict__[key] = value

    def save_as_json(self, path):
        """Write the configuration to *path* as pretty-printed JSON."""
        with open(path, "w", encoding="utf-8") as config_file:
            json.dump(self.to_dict(), config_file, indent=2, sort_keys=True)
            config_file.write("\n")
class ModelSpec(LayerSpec):
    """Base class for full model specifications."""

    def __init__(self):
        self._config = self.get_default_config()
        self._files = {}

    @property
    def name(self):
        """Name of the model specification (must be overridden)."""
        raise NotImplementedError()

    @property
    def revision(self):
        """Revision number of this specification."""
        return 1

    @property
    def config(self):
        """The model configuration object, or None when there is none."""
        return self._config

    def get_default_config(self):
        """Return the default configuration for this model type (or None)."""
        return None

    def register_file(self, path: str, filename: Optional[str] = None) -> None:
        """Register an extra file to copy into the saved model directory.

        Args:
          path: path of an existing file.
          filename: name under which the file is saved; defaults to the
            basename of *path*.

        Raises:
          ValueError: if *path* does not exist or *filename* is already used.
        """
        if not os.path.isfile(path):
            raise ValueError("File %s does not exist" % path)
        name = os.path.basename(path) if filename is None else filename
        if name in self._files:
            raise ValueError("A file with name %s was already registered" % name)
        self._files[name] = path

    def save(self, output_dir: str) -> None:
        """Save the serialized model, its configuration, and registered files."""
        self._serialize(os.path.join(output_dir, "model.bin"))
        if self._config is not None:
            self._config.save_as_json(os.path.join(output_dir, "config.json"))
        for filename, path in self._files.items():
            destination = os.path.join(output_dir, filename)
            if os.path.exists(destination):
                raise RuntimeError(
                    "File %s already exists in the model directory" % destination
                )
            shutil.copy(path, destination)

    def _serialize(self, path):
        """Write all variables and aliases to *path* in the binary format."""
        variables = []
        aliases = []
        for item in self.variables(ordered=True):
            # String values are aliases to another variable's name.
            target = aliases if isinstance(item[1], str) else variables
            target.append(item)

        with open(path, "wb") as model:

            def _write_string(string):
                # Strings are length-prefixed (including the terminator)
                # and null-terminated.
                model.write(struct.pack("H", len(string) + 1))
                model.write(string.encode("utf-8"))
                model.write(struct.pack("B", 0))

            model.write(struct.pack("I", CURRENT_BINARY_VERSION))
            _write_string(self.name)
            model.write(struct.pack("I", self.revision))

            model.write(struct.pack("I", len(variables)))
            for name, value in variables:
                _write_string(name)
                model.write(struct.pack("B", len(value.shape)))
                for dim in value.shape:
                    model.write(struct.pack("I", dim))
                model.write(struct.pack("B", _dtype_to_type_id(value.dtype)))
                model.write(struct.pack("I", value.num_bytes()))
                model.write(value.to_bytes())

            model.write(struct.pack("I", len(aliases)))
            for alias, variable_name in aliases:
                _write_string(alias)
                _write_string(variable_name)
def _flatten_vocabularies(vocabularies):
for name, vocabulary in vocabularies.items():
if len(vocabulary) == 1:
yield name, vocabulary[0]
else:
for i, vocab in enumerate(vocabulary):
yield "%s_%d" % (name, i + 1), vocab
class SequenceToSequenceModelConfig(ModelConfig):
    """Configuration for sequence-to-sequence models."""

    def __init__(
        self,
        unk_token: str = "<unk>",
        bos_token: str = "<s>",
        eos_token: str = "</s>",
        decoder_start_token: Optional[str] = "<s>",
        add_source_bos: bool = False,
        add_source_eos: bool = False,
        **kwargs,
    ):
        """Store the special-token settings plus any extra attributes."""
        kwargs.update(
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
            decoder_start_token=decoder_start_token,
            add_source_bos=add_source_bos,
            add_source_eos=add_source_eos,
        )
        super().__init__(**kwargs)
class SequenceToSequenceModelSpec(ModelSpec):
    """Base specification for sequence-to-sequence models."""

    def __init__(self):
        super().__init__()
        # Each side may have one or more vocabularies (e.g. factored inputs).
        self._vocabularies = {
            "source": [],
            "target": [],
        }

    def get_default_config(self):
        return SequenceToSequenceModelConfig()

    @abc.abstractmethod
    def get_source_vocabulary_size(self):
        """Return the expected source vocabulary size(s) (int or list of int)."""
        raise NotImplementedError()

    @abc.abstractmethod
    def get_target_vocabulary_size(self):
        """Return the expected target vocabulary size(s) (int or list of int)."""
        raise NotImplementedError()

    def register_source_vocabulary(self, tokens: List[str]) -> None:
        """Register an additional source vocabulary."""
        self._vocabularies["source"].append(tokens)

    def register_target_vocabulary(self, tokens: List[str]) -> None:
        """Register an additional target vocabulary."""
        self._vocabularies["target"].append(tokens)

    def register_vocabulary_mapping(self, path: str) -> None:
        """Register a vocabulary mapping file, saved as vmap.txt."""
        self.register_file(path, "vmap.txt")

    def validate(self) -> None:
        """Validate the spec and check vocabulary counts and sizes.

        Raises:
          ValueError: if the number or size of registered vocabularies does
            not match what the model expects.
        """
        super().validate()

        vocabulary_sizes = {
            "source": self.get_source_vocabulary_size(),
            "target": self.get_target_vocabulary_size(),
        }

        for name, sizes in vocabulary_sizes.items():
            if not isinstance(sizes, list):
                sizes = [sizes]
            vocabularies = self._vocabularies[name]
            if len(vocabularies) != len(sizes):
                raise ValueError(
                    "Incorrect number of %s vocabularies: %d registered, but expected %d"
                    % (name, len(vocabularies), len(sizes))
                )
            for i, (vocabulary, expected_size) in enumerate(zip(vocabularies, sizes)):
                if len(vocabulary) != expected_size:
                    raise ValueError(
                        "%s vocabulary %d has size %d but the model expected a vocabulary "
                        "of size %d"
                        % (name.capitalize(), i, len(vocabulary), expected_size)
                    )

    def save(self, output_dir: str) -> None:
        """Save the vocabularies then the rest of the model.

        When every registered vocabulary is identical, a single shared
        vocabulary file is saved instead.
        """
        vocabularies = dict(_flatten_vocabularies(self._vocabularies))
        all_vocabularies = list(vocabularies.values())
        # Guard against the no-vocabulary case: all() is vacuously true on an
        # empty sequence and indexing [0] below would raise IndexError.
        if all_vocabularies and all(
            vocabulary == all_vocabularies[0] for vocabulary in all_vocabularies
        ):
            vocabularies = {"shared": all_vocabularies[0]}
        for name, tokens in vocabularies.items():
            _save_vocabulary(output_dir, "%s_vocabulary" % name, tokens)
        super().save(output_dir)
class LanguageModelConfig(ModelConfig):
    """Configuration for language models."""

    def __init__(
        self,
        unk_token: str = "<unk>",
        bos_token: str = "<s>",
        eos_token: str = "</s>",
        **kwargs,
    ):
        """Store the special-token settings plus any extra attributes."""
        kwargs.update(
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
        )
        super().__init__(**kwargs)
class LanguageModelSpec(ModelSpec):
    """Base specification for language models."""

    def __init__(self):
        super().__init__()
        self._vocabulary = []

    def get_default_config(self):
        return LanguageModelConfig()

    @abc.abstractmethod
    def get_vocabulary_size(self):
        """Return the vocabulary size expected by the model."""
        raise NotImplementedError()

    def register_vocabulary(self, tokens: List[str]) -> None:
        """Register the model vocabulary (a copy of *tokens* is stored)."""
        self._vocabulary = list(tokens)

    def validate(self) -> None:
        """Validate the spec and check the vocabulary size.

        Raises:
          ValueError: if the registered vocabulary size does not match.
        """
        super().validate()
        expected_vocabulary_size = self.get_vocabulary_size()
        if len(self._vocabulary) != expected_vocabulary_size:
            raise ValueError(
                "Vocabulary has size %d but the model expected a vocabulary of size %d"
                % (len(self._vocabulary), expected_vocabulary_size)
            )

    def save(self, output_dir: str) -> None:
        """Save the vocabulary then the rest of the model."""
        _save_vocabulary(output_dir, "vocabulary", self._vocabulary)
        super().save(output_dir)
def _save_vocabulary(output_dir, name, tokens):
vocabulary_path = os.path.join(output_dir, "%s.json" % name)
with open(vocabulary_path, "w", encoding="utf-8") as vocabulary_file:
json.dump(tokens, vocabulary_file, indent=2)
class Variable(abc.ABC):
    """Abstract wrapper around a model weight value."""

    @property
    @abc.abstractmethod
    def shape(self) -> List[int]:
        """Dimensions of the variable."""
        raise NotImplementedError()

    @property
    @abc.abstractmethod
    def dtype(self) -> str:
        """Name of the data type."""
        raise NotImplementedError()

    def is_scalar(self) -> bool:
        """True when the variable has rank 0."""
        return len(self.shape) == 0

    def to(self, dtype: str) -> "Variable":
        """Convert to *dtype*; returns self unchanged if already that type."""
        return self if dtype == self.dtype else self._to(dtype)

    def equal(self, other) -> bool:
        """True when *other* has the same concrete type and an equal value."""
        return type(self) is type(other) and self._equal(other)

    @abc.abstractmethod
    def numpy(self) -> np.ndarray:
        """Return the value as a NumPy array."""
        raise NotImplementedError()

    @abc.abstractmethod
    def num_bytes(self) -> int:
        """Size of the raw data in bytes."""
        raise NotImplementedError()

    @abc.abstractmethod
    def to_bytes(self) -> bytes:
        """Raw bytes of the value."""
        raise NotImplementedError()

    @abc.abstractmethod
    def _to(self, dtype: str) -> "Variable":
        # Backend-specific conversion, called only when dtype differs.
        raise NotImplementedError()

    @abc.abstractmethod
    def _equal(self, other) -> bool:
        # Backend-specific equality, called only on matching types.
        raise NotImplementedError()
class NumpyVariable(Variable):
    """Variable backed by a NumPy array (or NumPy scalar)."""

    def __init__(self, array):
        self.array = array

    @property
    def shape(self) -> List[int]:
        return self.array.shape

    @property
    def dtype(self) -> str:
        return self.array.dtype.name

    def numpy(self) -> np.ndarray:
        return self.array

    def num_bytes(self) -> int:
        return self.array.nbytes

    def to_bytes(self) -> bytes:
        return self.array.tobytes()

    def _to(self, dtype: str) -> Variable:
        # NumPy has no bfloat16 dtype: hand over to a torch-backed variable.
        if dtype == "bfloat16":
            if not torch_is_available:
                raise RuntimeError(
                    "Converting to bfloat16 requires torch to be installed"
                )
            return PyTorchVariable.from_numpy(self.array).to(dtype)
        self.array = self.array.astype(np.dtype(dtype))
        return self

    def _equal(self, other) -> bool:
        mine = self.array
        theirs = other.array
        if mine is theirs:
            return True
        # Cheap checks (dtype, shape, first element) before the full scan.
        return (
            mine.dtype == theirs.dtype
            and mine.shape == theirs.shape
            and mine.flat[0] == theirs.flat[0]
            and np.array_equal(mine, theirs)
        )
class PyTorchVariable(Variable):
    """Variable backed by a PyTorch tensor."""

    def __init__(self, tensor):
        # Unwrap nn.Parameter and make the storage contiguous so that
        # data_ptr-based serialization reads the values in order.
        if isinstance(tensor, torch.nn.Parameter):
            tensor = tensor.data
        self.tensor = tensor.contiguous()

    @classmethod
    def from_numpy(cls, array):
        """Build a PyTorchVariable from a NumPy array."""
        return cls(torch.from_numpy(array))

    @property
    def shape(self) -> List[int]:
        return list(self.tensor.shape)

    @property
    def dtype(self) -> str:
        # e.g. "torch.float32" -> "float32"
        return str(self.tensor.dtype).replace("torch.", "")

    def numpy(self) -> np.ndarray:
        return self.tensor.detach().numpy()

    def num_bytes(self) -> int:
        return self.tensor.numel() * self.tensor.element_size()

    def to_bytes(self) -> bytes:
        # ctypes.string_at cannot read more than INT32_MAX bytes at once,
        # so large tensors are copied out in chunks.
        max_chunk = 2**31 - 1
        remaining = self.num_bytes()
        chunks = []
        offset = 0
        while remaining > 0:
            size = min(remaining, max_chunk)
            chunks.append(ctypes.string_at(self.tensor.data_ptr() + offset, size))
            offset += size
            remaining -= size
        return b"".join(chunks)

    def _to(self, dtype: str) -> Variable:
        self.tensor = self.tensor.to(getattr(torch, dtype))
        return self

    def _equal(self, other) -> bool:
        mine = self.tensor
        theirs = other.tensor
        if mine is theirs:
            return True
        return mine.dtype == theirs.dtype and torch.equal(mine, theirs)