ct2rs 0.9.19

Rust bindings for OpenNMT/CTranslate2
Documentation
#include "ctranslate2/ops/dequantize.h"

#include "dispatch.h"

namespace ctranslate2 {
  namespace ops {

    // Builds the op, keeping the given activation type pointer in
    // `_activation_type`. Only the pointer is stored; nothing is dereferenced here.
    Dequantize::Dequantize(const ActivationType* activation_type)
      : _activation_type(activation_type) {}

    // Dequantizes `input` into `output` using the quantization scale(s) in `scale`.
    //
    // `output` is first resized with `output.resize_as(input)`, then the work is
    // selected from the input data type:
    //  * INT16: the input must be on the CPU and `scale` must be a scalar; the
    //    CPU int16_t -> float kernel is called.
    //  * INT8: `scale` must hold `input.size() / input.dim(-1)` values (one per
    //    slice along the last dimension); the kernel is dispatched on the device
    //    and floating point type of `output`.
    //
    // Throws std::invalid_argument when the input type is neither INT8 nor INT16,
    // or when one of the constraints above is not met.
    void Dequantize::operator()(const StorageView& input,
                                const StorageView& scale,
                                StorageView& output) const {
      PROFILE("Dequantize");
      output.resize_as(input);

      switch (input.dtype()) {
      case DataType::INT16: {
        if (input.device() != Device::CPU)
          throw std::invalid_argument("INT16 dequantization is only supported on CPU");
        if (!scale.is_scalar())
          throw std::invalid_argument("INT16 quantization scale should be a scalar value");
        dequantize<Device::CPU, int16_t, float>(input, scale, output);
        break;
      }

      case DataType::INT8: {
        const dim_t depth = input.dim(-1);

        // A zero-sized last dimension means there is no element to dequantize.
        // Stop here: dividing by it below would be an integer division by zero,
        // which is undefined behavior.
        if (depth == 0)
          break;

        const dim_t batch_size = input.size() / depth;
        if (scale.size() != batch_size)
          throw std::invalid_argument("INT8 dequantization expects per-batch scales");

        DEVICE_AND_FLOAT_DISPATCH("Dequantize", output.device(), output.dtype(),
                                  (dequantize<D, int8_t, T>(input, scale, output)));

        break;
      }

      default:
        throw std::invalid_argument("Dequantize: invalid quantized type " + dtype_name(input.dtype())
                                    + ", expected int8 or int16");
      }
    }

    // Dequantizes a GEMM result `c` into `y`.
    //
    // `y` is resized with `y.resize_as(c)`, then `dequantize_gemm_output` is
    // dispatched on the device and floating point type of `y`. The scales of both
    // GEMM operands, their transposition flags and the `bias` pointer are passed
    // to the kernel unchanged; no validation is performed at this level.
    void Dequantize::operator()(const StorageView& c,
                                const StorageView& a_scale,
                                const StorageView& b_scale,
                                const bool transpose_a,
                                const bool transpose_b,
                                StorageView& y,
                                const StorageView* bias) const {
      PROFILE("DequantizeGemmOutput");

      // The output takes its shape from the GEMM result before dispatching.
      y.resize_as(c);

      DEVICE_AND_FLOAT_DISPATCH("DequantizeGemmOutput",
                                y.device(),
                                y.dtype(),
                                (dequantize_gemm_output<D, T>(c,
                                                              a_scale,
                                                              b_scale,
                                                              transpose_a,
                                                              transpose_b,
                                                              bias,
                                                              y)));
    }

  }
}