from copy import copy, deepcopy
from functools import partial
from typing import Callable
import numpy as np
from .. import module as Float
from ..functional import concat, norm
from ..logger import get_logger
from ..module import Module
from ..module import qat as QAT
from ..module import quantized as Quantized
from ..module.qat import QATModule
from ..module.quantized import QuantizedModule
from ..tensor import Tensor
from ..utils.module_utils import set_expand_structure
from .qconfig import QConfig, ema_fakequant_qconfig
logger = get_logger(__name__)
def _get_quantable_module_names():
    """Collect the names of the concrete quantized module classes.

    Every attribute listed by ``dir(Quantized)`` is inspected; a name is kept
    when it is bound to a class that derives from ``QuantizedModule`` and is
    not ``QuantizedModule`` itself.

    Returns:
        list of str: the matching attribute names, in ``dir()`` order.
    """
    names = []
    for attr_name in dir(Quantized):
        candidate = getattr(Quantized, attr_name)
        # Anything that is not a class cannot be a quantized module type.
        if not isinstance(candidate, type):
            continue
        if not issubclass(candidate, QuantizedModule):
            continue
        # The base class itself is not a convertible module.
        if candidate != QuantizedModule:
            names.append(attr_name)
    return names
def _get_convert_dict():
    """Build the class-to-class conversion tables used by this module.

    For each name reported by ``_get_quantable_module_names`` the attribute of
    that name is looked up in the ``Float``, ``QAT`` and ``Quantized``
    namespaces (in that order), and the resulting classes are paired up.

    Returns:
        tuple: ``(float2qat, qat2quantized)`` where ``float2qat`` maps a float
        module class to its QAT class and ``qat2quantized`` maps a QAT class
        to its quantized class.
    """
    names = _get_quantable_module_names()
    # One list of classes per namespace, all aligned with ``names``.
    float_classes, qat_classes, quantized_classes = (
        [getattr(namespace, name) for name in names]
        for namespace in (Float, QAT, Quantized)
    )
    float2qat = dict(zip(float_classes, qat_classes))
    qat2quantized = dict(zip(qat_classes, quantized_classes))
    return float2qat, qat2quantized
# Default conversion tables, built once when this module is imported:
# float class -> QAT class, and QAT class -> quantized class.
_float2qat_dict, _qat2quantized_dict = _get_convert_dict()
# Every QAT class that has an entry in the QAT -> quantized table.
qat_modules = tuple(_qat2quantized_dict)
def quantize(module: Module, inplace: bool = True, mapping: dict = None):
    """Replace QAT submodules of ``module`` with their quantized counterparts.

    Args:
        module: root module whose submodules are converted.
        inplace: when False, the conversion is applied to a deep copy and the
            argument is left untouched.
        mapping: optional extra ``{qat class: quantized class}`` entries. They
            are merged into a copy of the default table and override default
            entries with the same key.

    Returns:
        The converted module (``module`` itself, or its deep copy when
        ``inplace`` is False).

    Raises:
        KeyError: a submodule matches by ``isinstance`` but its exact type has
            no entry in the conversion table (for example an unregistered
            subclass of a registered QAT class).
    """
    target = module if inplace else deepcopy(module)

    qat2quantized = copy(_qat2quantized_dict)
    if mapping is not None:
        qat2quantized.update(mapping)
    convertible = tuple(qat2quantized.keys())

    # The matches are collected into a list up front, before any replacement
    # is written back into the module tree.
    matches = list(
        target._flatten(
            with_key=True,
            with_parent=True,
            predicate=lambda mod: isinstance(mod, convertible),
        )
    )
    for key, qat_mod, _parent in matches:
        quantized_cls = qat2quantized[type(qat_mod)]
        set_expand_structure(target, key, quantized_cls.from_qat_module(qat_mod))

    return target
def quantize_qat(
    module: Module,
    inplace: bool = True,
    qconfig: QConfig = ema_fakequant_qconfig,
    mapping: dict = None,
):
    """Replace float submodules of ``module`` with their QAT counterparts.

    Args:
        module: root module whose submodules are converted.
        inplace: when False, the conversion is applied to a deep copy and the
            argument is left untouched.
        qconfig: configuration handed to ``propagate_qconfig`` once the
            replacement is done.
        mapping: optional extra ``{float class: QAT class}`` entries. They are
            merged into a copy of the default table and override default
            entries with the same key.

    Returns:
        The converted module (``module`` itself, or its deep copy when
        ``inplace`` is False).

    Raises:
        KeyError: a submodule matches by ``isinstance`` but its exact type has
            no entry in the conversion table.
    """
    target = module if inplace else deepcopy(module)

    float2qat = copy(_float2qat_dict)
    if mapping is not None:
        float2qat.update(mapping)
    float_types = tuple(float2qat.keys())

    def is_quantable(mod: Module):
        return isinstance(mod, float_types)

    # The matches are collected into a list up front, before any replacement
    # is written back into the module tree.
    matches = list(
        target._flatten(with_key=True, with_parent=True, predicate=is_quantable)
    )
    for key, float_mod, parent in matches:
        # Convert only when the parent is not itself quantable and the
        # submodule has not opted out through ``quantize_disabled``.
        if not is_quantable(parent) and not float_mod.quantize_disabled:
            qat_cls = float2qat[type(float_mod)]
            set_expand_structure(target, key, qat_cls.from_float_module(float_mod))

    propagate_qconfig(target, qconfig)
    return target
def reset_qconfig(module: Module, qconfig: QConfig, inplace: bool = True):
    """Rebuild observers and fake-quant objects of every QAT submodule.

    For each ``QATModule`` found under ``module`` the current weight and
    activation qparams are read, new observer / fake-quant instances are
    created from the factories stored on ``qconfig``, and the saved qparams are
    pushed into each new instance that exposes ``set_qparams``.

    Args:
        module: root module to update.
        qconfig: provides the ``weight_observer``, ``weight_fake_quant``,
            ``act_observer`` and ``act_fake_quant`` factories; a factory may be
            ``None``, in which case the attribute is set to ``None``.
        inplace: when False, a deep copy is updated instead of ``module``.

    Returns:
        The updated module (``module`` itself, or its deep copy when
        ``inplace`` is False).
    """
    target = module if inplace else deepcopy(module)

    def build(factory, qparams):
        # A missing factory means the corresponding slot is left empty.
        if factory is None:
            return None
        instance = factory()
        if instance is not None and getattr(instance, "set_qparams", None) is not None:
            instance.set_qparams(qparams)
        return instance

    qat_mods = list(target._flatten(predicate=lambda mod: isinstance(mod, QATModule)))
    for qat_mod in qat_mods:
        if qat_mod.with_weight:
            weight_qparams = qat_mod.get_weight_qparams()
            qat_mod.weight_observer = build(qconfig.weight_observer, weight_qparams)
            qat_mod.weight_fake_quant = build(
                qconfig.weight_fake_quant, weight_qparams
            )
        if qat_mod.with_act:
            act_qparams = qat_mod.get_activation_qparams()
            qat_mod.act_observer = build(qconfig.act_observer, act_qparams)
            qat_mod.act_fake_quant = build(qconfig.act_fake_quant, act_qparams)

    return target
def _propagate(module: Module, func_str: str, *args, **kargs):
    """Invoke a named method on every ``QATModule`` visited by ``module.apply``.

    Args:
        module: root module; its ``apply`` method drives the traversal.
        func_str: name of the method to look up on each ``QATModule``.
        *args: positional arguments forwarded to that method.
        **kargs: keyword arguments forwarded to that method.
    """

    def call_on_qat(mod: Module):
        # Modules that are not QAT modules are left alone.
        if not isinstance(mod, QATModule):
            return
        method = getattr(mod, func_str)
        method(*args, **kargs)

    module.apply(call_on_qat)
def propagate_qconfig(module: QATModule, qconfig: QConfig):
    """Call ``set_qconfig(qconfig)`` on every ``QATModule`` under ``module``.

    Args:
        module: root module to traverse.
        qconfig: configuration passed to each ``set_qconfig`` call.
    """
    method_name = "set_qconfig"
    _propagate(module, method_name, qconfig)
def hook_qat_module(module: Module, func: Callable):
    """Register ``func`` as a forward hook on every ``QATModule`` in ``module``.

    Args:
        module: root module to search.
        func: callable passed to each submodule's ``register_forward_hook``.

    Returns:
        list: the values returned by ``register_forward_hook``, one per
        matching submodule, in traversal order.
    """
    # Collect the targets first, then register the hooks on that fixed list.
    qat_mods = list(module._flatten(predicate=lambda mod: isinstance(mod, QATModule)))
    return [qat_mod.register_forward_hook(func) for qat_mod in qat_mods]
def apply_easy_quant(
    module: Module, data: Tensor, start: float = 0.8, stop: float = 1.2, num: int = 40
):
    """Search a scale for the weight and activation observers of QAT modules.

    ``data`` is concatenated with itself and fed through ``module`` twice: the
    first pass tunes each QAT submodule's ``weight_observer``, the second its
    ``act_observer``. For each observer, ``num`` candidate scales evenly spaced
    between ``start * orig_scale`` and ``stop * orig_scale`` are tried, and the
    one whose fake-quantized output scores highest against the output computed
    with fake quantization disabled is kept.

    Args:
        module: root module containing the QAT submodules to tune.
        data: input tensor; ``data.shape[0]`` is used as the split point
            between the two halves of the doubled input.
        start: multiplier of ``orig_scale`` giving the first candidate.
        stop: multiplier of ``orig_scale`` giving the last candidate.
        num: number of candidates generated by ``np.linspace``.

    Returns:
        The same ``module`` object that was passed in.
    """
    batch_size = data.shape[0]

    def get_cosine(x, y):
        # Sum of products over every axis except the first, divided by the
        # product of the two ``norm`` values, then averaged over axis 0.
        # NOTE(review): this is a cosine similarity provided ``norm`` defaults
        # to the L2 norm -- confirm against ``functional.norm``.
        ndim = len(x.shape)
        axis = tuple(range(1, ndim))
        up = (x * y).sum(axis=axis)
        down = norm(x, axis=axis) * norm(y, axis=axis)
        sim = up / down
        return sim.mean(axis=0)

    def search(mod, inputs, outputs, where):
        # Forward hook. Hooks are cleared first because ``mod`` is called again
        # below; presumably this prevents the hook from re-entering itself.
        mod._forward_hooks.clear()
        # First ``batch_size`` entries feed the reference run, the remainder
        # feed the fake-quantized runs.
        normal_in = [_[:batch_size] for _ in inputs]
        fakequant_in = [_[batch_size:] for _ in inputs]
        # Reference output, computed with fake quantization switched off.
        disable_fake_quant(mod)
        normal_out = mod(*normal_in)
        enable_fake_quant(mod)
        ob = getattr(mod, where)
        if ob is None:
            # Nothing to tune for this module; the hook returns None.
            return
        orig_scale = ob.orig_scale
        cosine = optimal = 0
        for scale in np.linspace(start * orig_scale, stop * orig_scale, num):
            ob.scale = scale
            fakequant_out = mod(*fakequant_in)
            dis = get_cosine(normal_out, fakequant_out)
            # Keep the candidate with the highest similarity seen so far.
            if dis > cosine:
                cosine = dis
                optimal = scale
        if optimal == 0:
            # NOTE(review): on this path ``ob.scale`` keeps the last candidate
            # assigned in the loop rather than ``orig_scale`` -- confirm that
            # this is intended.
            logger.warning("EasyQuant finds no better scale")
        else:
            ob.scale = optimal
        # The second half is taken from the output the hook received, not from
        # the runs made during the search.
        fakequant_out = outputs[batch_size:]
        # Presumably the value returned by a forward hook replaces the module
        # output -- TODO confirm against ``register_forward_hook``.
        return concat([normal_out, fakequant_out])

    # Double the input so that each hook can split it into two halves.
    data = concat([data, data])
    # Pass 1: tune weight observers. Each hook removes itself when it runs
    # (see ``search``), so hooks are registered again before pass 2.
    hook_qat_module(module, partial(search, where="weight_observer"))
    module(data)
    # Pass 2: tune activation observers.
    hook_qat_module(module, partial(search, where="act_observer"))
    module(data)
    return module
def disable_fake_quant(module: Module):
    """Call ``set_fake_quant(False)`` on every ``QATModule`` under ``module``.

    Args:
        module: root module to traverse.
    """
    state = False
    _propagate(module, "set_fake_quant", state)
def disable_observer(module: Module):
    """Call ``set_observer(False)`` on every ``QATModule`` under ``module``.

    Args:
        module: root module to traverse.
    """
    state = False
    _propagate(module, "set_observer", state)
def enable_fake_quant(module: Module):
    """Call ``set_fake_quant(True)`` on every ``QATModule`` under ``module``.

    Args:
        module: root module to traverse.
    """
    state = True
    _propagate(module, "set_fake_quant", state)
def enable_observer(module: Module):
    """Call ``set_observer(True)`` on every ``QATModule`` under ``module``.

    Args:
        module: root module to traverse.
    """
    state = True
    _propagate(module, "set_observer", state)