#include "ctranslate2/ops/quantize.h"
#include <cmath>
#include "cpu/kernels.h"
#include "cpu/parallel.h"
namespace ctranslate2 {
namespace ops {
// CPU float -> int8 quantization: forwards the raw buffers to the ISA-specific
// cpu::quantize_s8 kernel selected by CPU_ISA_DISPATCH.
//
// input:  float values to quantize; its last dimension is passed as the depth.
// output: destination buffer for the int8 values.
// scale:  float buffer handed to the kernel; its element count is passed as
//         the batch size (presumably one scale per row of input -- confirm
//         against cpu::quantize_s8).
template<>
void Quantize::quantize<Device::CPU, float, int8_t>(const StorageView& input,
                                                    StorageView& output,
                                                    StorageView& scale) const {
  const dim_t num_rows = scale.size();
  const dim_t row_width = input.dim(-1);

  const float* src = input.data<float>();
  int8_t* dst = output.data<int8_t>();
  float* scales = scale.data<float>();

  // Keep the dispatched statement a single call expression so that every
  // comma stays inside parentheses when passed through the macro.
  CPU_ISA_DISPATCH(cpu::quantize_s8<ISA>(src, dst, scales, num_rows, row_width, _shift_to_uint8, _round_before_cast));
}
// Quantizes `size` floats read from `x` into int16 values written to `y`.
//
// Each element is multiplied by `scale`, passed through `round_func`, and then
// clamped to [int16 lowest, int16 max] before being handed back to
// cpu::parallel_unary_transform, which stores it in `y`.
//
// x:          source values.
// scale:      multiplier applied to every value before rounding.
// y:          destination buffer.
// size:       number of elements to process.
// round_func: callable taking a float and returning a float (e.g. a rounding
//             function, or a pass-through when no rounding is wanted).
template <typename RoundFunc>
static void quantize_s16_kernel(const float* x,
                                const float scale,
                                int16_t* y,
                                dim_t size,
                                const RoundFunc& round_func) {
  using int16_limits = std::numeric_limits<int16_t>;
  constexpr float min_value = int16_limits::lowest();
  constexpr float max_value = int16_limits::max();

  const auto quantize_one = [scale, min_value, max_value, &round_func](float value) {
    const float rounded = round_func(value * scale);
    // Saturate on the upper bound first, then on the lower bound.
    const float capped = std::min(rounded, max_value);
    return std::max(capped, min_value);
  };

  // NOTE(review): the literal 5 is forwarded unchanged; it looks like a
  // per-element work/cost hint for the parallel helper -- confirm.
  cpu::parallel_unary_transform(x, y, size, 5, quantize_one);
}
// CPU float -> int16 quantization.
//
// input:  float values to quantize (all input.size() elements are processed).
// output: destination buffer for the int16 values.
// scale:  overwritten with a scalar StorageView holding the scale that was
//         applied to the input.
//
// The scale is `global_int16_scale` unless the op is configured with
// ScaleType::PER_LAYER, in which case it is derived from the absolute maximum
// of the input so that the largest magnitude maps to 1 << 10.
template<>
void Quantize::quantize<Device::CPU, float, int16_t>(const StorageView& input,
                                                     StorageView& output,
                                                     StorageView& scale) const {
  const dim_t size = input.size();
  const auto* input_data = input.data<float>();
  auto* output_data = output.data<int16_t>();

  float scale_value = global_int16_scale;
  if (_int16_scale_type == ScaleType::PER_LAYER) {
    const float amax = primitives<Device::CPU>::amax(input_data, size);
    // Map the largest magnitude to 2^10 (presumably to keep headroom in the
    // int16 range for later accumulation -- confirm).
    // Guard against an all-zero input: dividing by amax == 0 would yield an
    // infinite scale, 0 * inf == NaN for every element, and converting NaN to
    // int16_t is undefined behaviour. Any finite scale quantizes zeros to
    // zero, so fall back to 1.
    scale_value = (amax != 0.f ? static_cast<float>(1 << 10) / amax : 1.f);
  }

  scale = StorageView(scale_value);

  if (_round_before_cast)
    quantize_s16_kernel(input_data, scale_value, output_data, size, std::nearbyintf);
  else
    quantize_s16_kernel(input_data, scale_value, output_data, size, cpu::identity());
}
}
}