#include "ctranslate2/ops/conv1d.h"
#ifdef CT2_WITH_DNNL
# include <dnnl.hpp>
namespace ctranslate2 {
namespace ops {
// CPU float specialization of Conv1D built on oneDNN (compiled when
// CT2_WITH_DNNL is defined).
//
// input   source tensor, described to oneDNN with the `ncw` layout.
// weight  kernel tensor, described to oneDNN as grouped `goiw` weights with
//         dims {_groups, weight.dim(0) / _groups, weight.dim(1), weight.dim(2)}.
// bias    optional 1-D bias; skipped when null.
// output  destination tensor (`ncw`); its shape must already be set because
//         it is read to build the destination descriptor.
// qscale  must be null: quantized weights are not handled here.
//
// Throws std::runtime_error when qscale is provided. oneDNN itself may throw
// if the descriptors are inconsistent.
template<>
void Conv1D::compute<Device::CPU, float>(const StorageView& input,
                                         const StorageView& weight,
                                         const StorageView* bias,
                                         StorageView& output,
                                         const StorageView* qscale) const {
  if (qscale)
    throw std::runtime_error("Quantization is not supported in this Conv1D implementation");

  dnnl::engine engine(dnnl::engine::kind::cpu, 0);
  dnnl::stream engine_stream(engine);

  dnnl::memory::dims input_dims(input.shape().begin(), input.shape().end());
  dnnl::memory::dims output_dims(output.shape().begin(), output.shape().end());
  dnnl::memory::dims weight_dims{_groups, weight.dim(0) / _groups, weight.dim(1), weight.dim(2)};

  using tag = dnnl::memory::format_tag;
  using dt = dnnl::memory::data_type;

  // Descriptors with tag::any let the primitive pick the layout it prefers.
  dnnl::memory::desc input_md(input_dims, dt::f32, tag::any);
  dnnl::memory::desc output_md(output_dims, dt::f32, tag::any);
  dnnl::memory::desc weight_md(weight_dims, dt::f32, tag::any);

  // Memory objects wrapping the caller's buffers in their actual layout.
  dnnl::memory input_mem({input_dims, dt::f32, tag::ncw}, engine,
                         const_cast<void*>(input.buffer()));
  dnnl::memory output_mem({output_dims, dt::f32, tag::ncw}, engine,
                          output.buffer());
  dnnl::memory weight_mem({weight_dims, dt::f32, tag::goiw}, engine,
                          const_cast<void*>(weight.buffer()));

  dnnl::memory::dims stride{_stride};
  // oneDNN counts dilation from zero: 0 means a dense (non-dilated) kernel and
  // the effective kernel extent is 1 + (K - 1) * (D + 1). A dilation of d in
  // the usual "1 == dense" convention therefore maps to d - 1. Values <= 1 are
  // clamped to 0 as before.
  dnnl::memory::dims dilation{_dilation > 1 ? _dilation - 1 : 0};
  dnnl::memory::dims padding{_padding};

  std::unique_ptr<dnnl::convolution_forward::primitive_desc> conv_pd;
  std::unordered_map<int, dnnl::memory> args;
  args.reserve(4);

  if (bias) {
    dnnl::memory::dims bias_dims(bias->shape().begin(), bias->shape().end());
    dnnl::memory::desc bias_md(bias_dims, dt::f32, tag::a);
    dnnl::memory bias_mem(bias_md, engine, const_cast<void*>(bias->buffer()));
    args.emplace(DNNL_ARG_BIAS, bias_mem);

    conv_pd = std::make_unique<dnnl::convolution_forward::primitive_desc>(
      engine,
      dnnl::prop_kind::forward_inference,
      dnnl::algorithm::convolution_direct,
      input_md,
      weight_md,
      bias_md,
      output_md,
      stride,
      dilation,
      padding,
      padding);
  } else {
    conv_pd = std::make_unique<dnnl::convolution_forward::primitive_desc>(
      engine,
      dnnl::prop_kind::forward_inference,
      dnnl::algorithm::convolution_direct,
      input_md,
      weight_md,
      output_md,
      stride,
      dilation,
      padding,
      padding);
  }

  dnnl::memory conv_input_mem = input_mem;
  dnnl::memory conv_weight_mem = weight_mem;
  dnnl::memory conv_output_mem = output_mem;

  // Reorder the operands into the layout chosen by the primitive when it
  // differs from the layout of the caller's buffers.
  if (conv_pd->src_desc() != input_mem.get_desc()) {
    conv_input_mem = dnnl::memory(conv_pd->src_desc(), engine);
    dnnl::reorder(input_mem, conv_input_mem)
      .execute(engine_stream, input_mem, conv_input_mem);
  }

  if (conv_pd->weights_desc() != weight_mem.get_desc()) {
    conv_weight_mem = dnnl::memory(conv_pd->weights_desc(), engine);
    dnnl::reorder(weight_mem, conv_weight_mem)
      .execute(engine_stream, weight_mem, conv_weight_mem);
  }

  // The destination only needs a scratch buffer here; it is copied back into
  // the caller's buffer after the convolution has run.
  if (conv_pd->dst_desc() != output_mem.get_desc()) {
    conv_output_mem = dnnl::memory(conv_pd->dst_desc(), engine);
  }

  args.emplace(DNNL_ARG_SRC, conv_input_mem);
  args.emplace(DNNL_ARG_WEIGHTS, conv_weight_mem);
  args.emplace(DNNL_ARG_DST, conv_output_mem);

  dnnl::convolution_forward conv(*conv_pd);
  conv.execute(engine_stream, args);

  if (conv_pd->dst_desc() != output_mem.get_desc()) {
    dnnl::reorder(conv_output_mem, output_mem)
      .execute(engine_stream, conv_output_mem, output_mem);
  }

  // Make sure every queued primitive has finished before touching `output`.
  engine_stream.wait();

  if (_activation_type)
    get_activation_op(*_activation_type)(output, output);
}
}
}
#else
# include "ctranslate2/ops/gemm.h"
# include "cpu/parallel.h"
# include "ctranslate2/ops/quantize.h"
# include "ctranslate2/ops/dequantize.h"
namespace ctranslate2 {
namespace ops {
template<>
void
Conv1D::compute<Device::CPU, float>(const StorageView &input, const StorageView &weight, const StorageView *bias,
StorageView &output, const StorageView *qscale) const {
if (_dilation != 1)
throw std::runtime_error("Dilation is not supported in this Conv1D implementation");
compute_with_gemm(input, weight, output, qscale);
apply_bias_and_activation(output, bias, _activation_type, nullptr, -2);
}
void Conv1D::compute_with_gemm(const StorageView &input, const StorageView &weight, StorageView &output,
const StorageView *qscale) const {
const dim_t batch_size = input.dim(0);
const dim_t in_channels = input.dim(1);
const dim_t out_channels = weight.dim(0);
const dim_t kernel_size = weight.dim(2);
const dim_t output_length = output.dim(2);
const dim_t in_channels_per_group = in_channels / _groups;
StorageView im2col_output({batch_size, _groups, output_length, in_channels_per_group * kernel_size}, 0.0f, weight.device());
im2col_transposed(input, im2col_output, kernel_size);
const dim_t m = out_channels / _groups;
const dim_t n = output_length;
const dim_t k = in_channels_per_group * kernel_size;
const dim_t stridew = out_channels / _groups * in_channels_per_group * kernel_size * weight.item_size();
const dim_t strideb = k * output_length;
const dim_t stridec = m * output_length;
const dim_t qscale_stride = qscale ? qscale->dim(0) / _groups : 0;
auto* w = static_cast<int8_t*>(const_cast<void*>(weight.buffer()));
auto* b = im2col_output.data<float>();
auto* c = output.data<float>();
const Gemm gemm(1.0, 0.0, false, true);
const Quantize quantize_op(Quantize::ScaleType::PER_LAYER,
false,
true);
const Dequantize dequantize_op;
const auto device = im2col_output.device();
cpu::parallel_for(0, batch_size * _groups, 1, [&](dim_t begin, dim_t end) {
StorageView qinput(weight.dtype(), device);
StorageView qinput_scale(device);
if (qscale)
qinput_scale.to(qscale->dtype());
StorageView qoutput(DataType::INT32, device);
for (dim_t i = begin; i < end; ++i) {
auto group_index = i % _groups;
void* w_i = w + (group_index * stridew);
float* b_i = b + (i * strideb);
float* c_i = c + (i * stridec);
StorageView aa(weight.dtype(), weight.device());
aa.view(w_i, {m, k});
StorageView bb({n, k}, b_i); StorageView cc({m, n}, c_i);
if (qscale) {
StorageView group_qscale({qscale_stride}, const_cast<float *>(qscale->data<float>()) + group_index * qscale_stride);
quantize_op(bb, qinput, qinput_scale);
gemm(aa, qinput, qoutput);
dequantize_op(qoutput,
group_qscale,
qinput_scale,
false,
true,
cc);
} else {
gemm(aa, bb, cc);
}
}
});
}
void Conv1D::im2col_transposed(const StorageView& input, StorageView& output, const dim_t kernel_size) const {
const dim_t batch_size = input.dim(0);
const dim_t in_channels = input.dim(1);
const dim_t input_length = input.dim(2);
auto* out = output.data <float>();
const auto* in = input.data <float>();
dim_t out_offset = 0;
const auto in_batch_stride = in_channels * input_length;
const dim_t in_channels_per_group = in_channels / _groups;
const dim_t in_group_stride = in_channels_per_group * input_length;
for (dim_t batch_offset = 0; batch_offset < batch_size * in_batch_stride; batch_offset += in_batch_stride) {
for (dim_t group_offset = batch_offset; group_offset < (batch_offset + _groups * in_group_stride); group_offset += in_group_stride) {
for (dim_t ti = -_padding; ti <= (input_length - kernel_size + _padding); ti += _stride) {
for (dim_t c = group_offset; c < (group_offset + in_channels_per_group * input_length); c += input_length) {
for (int k = 0; k < kernel_size; k++) {
auto window_i = k + ti;
if (0 <= window_i && window_i < input_length) {
out[out_offset] = in[window_i + c];
}
out_offset += 1;
}
}
}
}
}
}
}
}
#endif