// ct2rs 0.9.19
//
// Rust bindings for OpenNMT/CTranslate2
// Documentation
#include "backend.h"

#include "ctranslate2/utils.h"
#include "cpu_info.h"
#include "env.h"

namespace ctranslate2 {
  namespace cpu {

    // Decides whether Intel MKL may be used, driven by the CT2_USE_MKL
    // environment variable.
    //
    // - Non-empty value: it is converted with string_to_bool() and returned.
    //   If it requests MKL while the binary was built without CT2_WITH_MKL,
    //   std::invalid_argument is thrown.
    // - Empty value: builds with CT2_WITH_MKL return cpu_is_genuine_intel();
    //   builds without it return false.
    static bool mayiuse_mkl_init() {
      const std::string env_value = read_string_from_env("CT2_USE_MKL");

      // An explicit setting takes precedence over the automatic choice.
      if (!env_value.empty()) {
        const bool requested = string_to_bool(env_value);
#ifndef CT2_WITH_MKL
        // Asking for MKL is an error when it was not compiled in.
        if (requested)
          throw std::invalid_argument("This CTranslate2 binary was not compiled with Intel MKL");
#endif
        return requested;
      }

      // No explicit setting: fall back to the build-dependent default.
#ifdef CT2_WITH_MKL
      return cpu_is_genuine_intel();
#else
      return false;
#endif
    }

    // Returns whether MKL may be used. The decision is made by
    // mayiuse_mkl_init() on the first call and kept in a function-local
    // static, so later calls return the same stored value.
    bool mayiuse_mkl() {
      static const bool cached_decision = mayiuse_mkl_init();
      return cached_decision;
    }

    // Returns the display name of a GEMM backend. Any value without a
    // dedicated case below is reported as "none".
    std::string gemm_backend_to_str(GemmBackend gemm_backend) {
      const char* label = "none";

      switch (gemm_backend) {
      case GemmBackend::MKL:
        label = "MKL";
        break;
      case GemmBackend::DNNL:
        label = "DNNL";
        break;
      case GemmBackend::ACCELERATE:
        label = "Accelerate";
        break;
      case GemmBackend::OPENBLAS:
        label = "OpenBLAS";
        break;
      case GemmBackend::RUY:
        label = "Ruy";
        break;
      default:
        break;
      }

      return label;
    }

    // Selects the GEMM backend to use for the given compute type.
    //
    // Candidates are tried in a fixed order (MKL, DNNL, Accelerate, OpenBLAS,
    // Ruy). Each candidate is only compiled in when its CT2_WITH_* macro is
    // defined, and the first one accepting the compute type is returned.
    // Returns GemmBackend::NONE when no compiled-in backend accepts it.
    GemmBackend get_gemm_backend(ComputeType compute_type) {
      const bool wants_float32 = (compute_type == ComputeType::FLOAT32);
      // Every INT8 variant is treated the same way for backend selection.
      const bool wants_int8 = (compute_type == ComputeType::INT8
                               || compute_type == ComputeType::INT8_FLOAT32
                               || compute_type == ComputeType::INT8_FLOAT16
                               || compute_type == ComputeType::INT8_BFLOAT16);

#ifdef CT2_WITH_MKL
      // mayiuse_mkl() is evaluated before the compute type test; MKL is the
      // only candidate here that also accepts INT16.
      if (mayiuse_mkl()
          && (wants_float32
              || compute_type == ComputeType::INT16
              || wants_int8))
        return GemmBackend::MKL;
#endif

#ifdef CT2_WITH_DNNL
      if (wants_float32 || wants_int8)
        return GemmBackend::DNNL;
#endif

#ifdef CT2_WITH_ACCELERATE
      if (wants_float32)
        return GemmBackend::ACCELERATE;
#endif

#ifdef CT2_WITH_OPENBLAS
      if (wants_float32)
        return GemmBackend::OPENBLAS;
#endif

#ifdef CT2_WITH_RUY
      if (wants_int8 || wants_float32)
        return GemmBackend::RUY;
#endif

      return GemmBackend::NONE;
    }

    // Tells whether get_gemm_backend() resolves the given compute type to
    // anything other than GemmBackend::NONE.
    bool has_gemm_backend(ComputeType compute_type) {
      const GemmBackend selected = get_gemm_backend(compute_type);
      return selected != GemmBackend::NONE;
    }

    bool prefer_u8s8s32_gemm() {
      const auto gemm_s8_backend = get_gemm_backend(ComputeType::INT8);
      return gemm_s8_backend == cpu::GemmBackend::MKL || gemm_s8_backend == cpu::GemmBackend::DNNL;
    }

    // Tells whether GEMM weights should be packed for the given compute type.
    //
    // Packing is opt-in through the CT2_USE_EXPERIMENTAL_PACKED_GEMM
    // environment variable, which is read once on the first call and kept in
    // a function-local static. Even when enabled, this returns true only if
    // the backend selected for the compute type is MKL.
    bool pack_gemm_weights(ComputeType compute_type) {
      static const bool packing_enabled = read_bool_from_env("CT2_USE_EXPERIMENTAL_PACKED_GEMM");

      // Skip the backend lookup entirely when the feature is switched off.
      if (!packing_enabled)
        return false;

      return get_gemm_backend(compute_type) == GemmBackend::MKL;
    }

#ifdef CT2_WITH_RUY
    // Returns a pointer to a ruy::Context declared static thread_local inside
    // this function, so each calling thread gets its own instance and the
    // same thread always receives the same pointer.
    ruy::Context* get_ruy_context() {
      static thread_local ruy::Context per_thread_context;
      return &per_thread_context;
    }
#endif
  }
}