// ct2rs 0.9.19
//
// Rust bindings for OpenNMT/CTranslate2
// Documentation
#include <ctranslate2/ops/awq/gemm.h>
#include <ctranslate2/ops/sum.h>
#include "dispatch.h"

namespace ctranslate2 {
  namespace ops {

    // Runs the AWQ quantized GEMM on a non-CPU device.
    //
    // The device-specific `compute<D, float16_t, int>` kernel is invoked with the
    // float16 input `a`, the int32 weight `b`, and the `scale` / `zero` tensors
    // (presumably the dequantization scales and zero points — confirm against
    // compute). Its result in `c` is then summed over dimension 0, that
    // dimension is squeezed away, and bias / activation / residual are applied
    // via apply_bias_and_activation using this op's `_activation_type`.
    //
    // Parameters:
    //   a        - input tensor; must have dtype FLOAT16 and must not be on CPU.
    //   b        - weight tensor; must have dtype INT32.
    //   scale    - forwarded unchanged to the device kernel.
    //   zero     - forwarded unchanged to the device kernel.
    //   c        - output; replaced by the dim-0-reduced result.
    //   bias     - forwarded unchanged to apply_bias_and_activation.
    //   residual - forwarded unchanged to apply_bias_and_activation.
    //
    // Throws:
    //   std::runtime_error    - when built with __CUDA_ARCH__ < 750.
    //   std::invalid_argument - when `a` is not FLOAT16, when `b` is not INT32,
    //                           or when `a` lives on the CPU.
    void GemmAwq::operator()(const StorageView& a,
                             const StorageView& b,
                             const StorageView& scale,
                             const StorageView& zero,
                             StorageView& c,
                             const StorageView* bias,
                             const StorageView* residual) const {
      PROFILE("Gemm Awq");
      // NOTE(review): __CUDA_ARCH__ is only defined during nvcc's device
      // compilation passes. For this host function the guard is presumably
      // never taken at runtime, so older GPUs may not hit this error. Consider
      // a runtime compute-capability check instead — TODO confirm.
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
      throw std::runtime_error("AWQ Gemm does not support for cuda arch < 7.5");
#else
      // Both dtype requirements must hold. Reject the call if *either* one is
      // violated, because the kernel below is instantiated for float16 input and
      // int weights only.
      if (a.dtype() != DataType::FLOAT16 || b.dtype() != DataType::INT32)
        throw std::invalid_argument("Awq gemm is only supported for float16 input and int32 weight");
      if (a.device() == Device::CPU)
        throw std::invalid_argument("Awq gemm is only supported on GPU");

      // `D` is presumably bound to the concrete device by DEVICE_DISPATCH.
      DEVICE_DISPATCH(a.device(), (compute<D, float16_t, int>(a, b, scale, zero, c)));

      // The kernel output carries a leading dimension that is reduced here
      // (presumably split partial results — confirm against compute). Sum over
      // it, drop it, and move the reduced tensor back into the output.
      StorageView tmp(c.dtype(), c.device());
      ops::Sum(0)(c, tmp);
      tmp.squeeze(0);
      c = std::move(tmp);

      apply_bias_and_activation(c, bias, _activation_type, residual);
#endif
    }
  }
}