#include "ctranslate2/decoding_utils.h"
#include <set>
#include "ctranslate2/ops/ops.h"
#include "dispatch.h"
namespace ctranslate2 {
DisableTokens::DisableTokens(StorageView& logits, const float disable_value)
: _logits(logits)
, _logits_data(logits.device() == Device::CPU ? logits.data<float>() : nullptr)
, _disable_value(disable_value)
, _batch_size(logits.dim(0))
, _vocabulary_size(logits.dim(1))
{
}
void DisableTokens::apply() {
const dim_t num_indices = _flat_indices.size();
if (num_indices == 0)
return;
const Device device = _logits.device();
const DataType dtype = _logits.dtype();
const StorageView flat_indices({num_indices}, _flat_indices, device);
DEVICE_AND_TYPE_DISPATCH(device, dtype,
primitives<D>::indexed_fill(_logits.data<T>(),
static_cast<T>(_disable_value),
flat_indices.data<int32_t>(),
num_indices));
_flat_indices.clear();
}
RepetitionPenalty::RepetitionPenalty(const float penalty)
: _penalty(penalty)
{
}
void RepetitionPenalty::apply(dim_t,
StorageView& logits,
DisableTokens&,
const StorageView& sequences,
const std::vector<dim_t>&,
const std::vector<std::vector<size_t>>*) {
if (!sequences)
return;
const Device device = logits.device();
const DataType dtype = logits.dtype();
StorageView previous_ids = sequences.to(device);
StorageView previous_scores(device, dtype);
ops::Gather(-1, 1)(logits, previous_ids, previous_scores);
DEVICE_AND_TYPE_DISPATCH(device, dtype,
primitives<D>::penalize_previous_tokens(logits.data<T>(),
previous_scores.data<T>(),
previous_ids.data<int32_t>(),
static_cast<T>(_penalty),
logits.dim(0),
previous_ids.dim(-1),
logits.dim(-1)));
}
NoRepeatNgram::NoRepeatNgram(const size_t ngram_size)
: _ngram_size(ngram_size)
{
}
void NoRepeatNgram::apply(dim_t,
StorageView&,
DisableTokens& disable_tokens,
const StorageView& sequences,
const std::vector<dim_t>&,
const std::vector<std::vector<size_t>>*) {
if (!sequences || sequences.dim(-1) < _ngram_size)
return;
const dim_t batch_size = sequences.dim(0);
const dim_t length = sequences.dim(1);
for (dim_t batch_id = 0; batch_id < batch_size; ++batch_id) {
const auto* begin = sequences.index<int32_t>({batch_id, 0});
const auto* end = begin + length;
const auto* current_ngram_begin = end - _ngram_size + 1;
std::set<size_t> ngram_final_tokens;
while (true) {
begin = std::search(begin, end, current_ngram_begin, end);
if (begin + _ngram_size > end)
break;
ngram_final_tokens.emplace(begin[_ngram_size - 1]);
begin += 1;
}
for (const auto token_id : ngram_final_tokens)
disable_tokens.add(batch_id, token_id);
}
}
SuppressSequences::SuppressSequences(std::vector<std::vector<size_t>> sequences) {
for (auto& sequence : sequences) {
if (sequence.empty())
continue;
if (sequence.size() == 1) _ids.emplace_back(sequence[0]);
else
_sequences.emplace_back(std::move(sequence));
}
}
void SuppressSequences::apply(dim_t,
StorageView&,
DisableTokens& disable_tokens,
const StorageView& sequences,
const std::vector<dim_t>&,
const std::vector<std::vector<size_t>>*) {
for (const auto token_id : _ids)
disable_tokens.add(token_id);
if (!sequences)
return;
const dim_t batch_size = sequences.dim(0);
const dim_t length = sequences.dim(1);
for (dim_t batch_id = 0; batch_id < batch_size; ++batch_id) {
const auto* begin = sequences.index<int32_t>({batch_id, 0});
const auto* end = begin + length;
for (const auto& banned_sequence : _sequences) {
const dim_t compare_length = banned_sequence.size() - 1;
if (length < compare_length)
continue;
const bool disable_last = std::equal(end - compare_length,
end,
banned_sequence.begin(),
banned_sequence.begin() + compare_length);
if (disable_last)
disable_tokens.add(batch_id, banned_sequence.back());
}
}
}
SuppressTokens::SuppressTokens(std::vector<size_t> ids)
: _ids(std::move(ids))
{
}
void SuppressTokens::apply(dim_t,
StorageView&,
DisableTokens& disable_tokens,
const StorageView&,
const std::vector<dim_t>&,
const std::vector<std::vector<size_t>>*) {
for (const auto token_id : _ids)
disable_tokens.add(token_id);
}
SuppressTokensBegin::SuppressTokensBegin(std::vector<size_t> ids)
: _ids(std::move(ids))
{
}
void SuppressTokensBegin::apply(dim_t step,
StorageView& logits,
DisableTokens& disable_tokens,
const StorageView&,
const std::vector<dim_t>& batch_offset,
const std::vector<std::vector<size_t>>* prefix) {
const dim_t batch_size = logits.dim(0);
for (dim_t batch_id = 0; batch_id < batch_size; ++batch_id) {
const dim_t sample_begin = get_sample_begin(batch_size, batch_id, batch_offset, prefix);
if (step != sample_begin)
continue;
for (const auto token_id : _ids)
disable_tokens.add(batch_id, token_id);
}
}
}