#include "ctranslate2/decoding.h"
#include <algorithm>
#include <cmath>
#include <memory>
#include <numeric>
#include "ctranslate2/ops/ops.h"
#include "dispatch.h"
namespace ctranslate2 {
static const ops::Gather gather;
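// Gathers rows of "data" in its flattened (batch * beam) view using "indices",
// then restores the separate batch and beam dimensions.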
static void gather_beam_flat(StorageView& data, const StorageView& indices, dim_t beam_size) {
merge_batch_beam(data);
gather(data, indices);
split_batch_beam(data, beam_size);
}
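// While a target prefix is active, forces the sampled ids to the prefix token (only the
// first candidate keeps a non-penalized score), and at the first unconstrained step replaces
// end tokens in the non-top candidates with a heavily penalized placeholder so the beam can
// keep expanding. When "beam_origins" is given, forced tokens are attributed to the first
// beam of their batch entry.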
static void update_sample_with_prefix(const size_t step,
StorageView& sampled_ids,
StorageView& sampled_scores,
const std::vector<std::vector<size_t>>& prefix_ids,
const std::vector<size_t>& end_ids,
const std::vector<dim_t>& batch_offset,
const dim_t beam_size = 1,
StorageView* beam_origins = nullptr,
const bool is_expanded = true) {
const dim_t batch_size = sampled_scores.dim(0);
for (dim_t i = 0; i < batch_size; ++i) {
const auto& prefix = prefix_ids[batch_offset[i]];
if (step > prefix.size())
continue;
const dim_t num_samples = sampled_scores.dim(1);
for (dim_t k = 0; k < num_samples; ++k) {
const dim_t flat_index = i * num_samples + k;
auto& sampled_id = sampled_ids.at<int32_t>(flat_index);
int32_t new_id = -1;
float new_score = 0;
if (step < prefix.size()) {
new_id = prefix[step];
new_score = (k == 0 ? 0.f : float(-1e10));
} else if (k > 0 && is_eos(sampled_id, end_ids)) {
new_id = 0;
new_score = -1e10;
}
if (new_id >= 0) {
sampled_id = new_id;
TYPE_DISPATCH(sampled_scores.dtype(), sampled_scores.at<T>(flat_index) = T(new_score));
if (beam_origins)
beam_origins->at<int32_t>(flat_index) = (is_expanded ? i * beam_size : i);
}
}
}
}
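// When the decoder output layer was reduced to a subset of the vocabulary, the sampled ids
// index that subset; convert them back to original word ids before they are fed as the next
// decoder input.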
static inline void convert_to_original_word_ids(const layers::Decoder& decoder,
StorageView& ids) {
if (!decoder.output_layer_is_updated())
return;
auto* ids_data = ids.data<int32_t>();
for (dim_t i = 0; i < ids.size(); ++i)
ids_data[i] = decoder.to_original_word_id(ids_data[i]);
}
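// Initializes the cumulative beam scores so that only the first beam of each batch entry
// starts at 0 and the others start at the lowest representable value.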
template <typename T>
static void initialize_beam_scores(StorageView& scores,
const dim_t batch_size,
const dim_t beam_size) {
const dim_t size = batch_size * beam_size;
scores.resize({size});
auto* data = scores.data<T>();
for (dim_t i = 0; i < size; ++i) {
data[i] = (i % beam_size == 0 ? T(0) : std::numeric_limits<T>::lowest());
}
}
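// The sampler returns flat indices over (beam * vocabulary). Split each index into its word id
// (written back into "ids") and its beam of origin, returned as gather indices over the
// flattened batch * beam dimension (or over the batch when the state is not yet expanded).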
static StorageView unflatten_ids(StorageView& ids,
const dim_t beam_size,
const dim_t vocabulary_size,
const bool is_expanded) {
const dim_t num_ids = ids.size();
StorageView beam_origins({num_ids}, DataType::INT32);
auto* ids_data = ids.data<int32_t>();
auto* origins_data = beam_origins.data<int32_t>();
for (dim_t i = 0; i < num_ids; ++i) {
const auto flat_id = ids_data[i];
const auto beam_id = flat_id / vocabulary_size;
const auto word_id = flat_id % vocabulary_size;
const auto batch_id = i / ids.dim(-1);
ids_data[i] = word_id;
origins_data[i] = is_expanded ? batch_id * beam_size + beam_id : batch_id;
}
return beam_origins;
}
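// Appends "step_output" as a new time step of "history" (dimension 2). When "beam_origins"
// is given, the history is first reordered to follow the parent beam of each candidate.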
static void append_step_output(StorageView& history,
StorageView step_output,
const StorageView* beam_origins = nullptr) {
step_output.expand_dims(2);
if (history) {
if (beam_origins)
gather_beam_flat(history, *beam_origins, step_output.dim(1));
const StorageView cur_history(std::move(history));
ops::Concat(2)({&cur_history, &step_output}, history);
} else {
history = std::move(step_output);
}
}
static std::vector<size_t> build_hypothesis(const StorageView& history,
const dim_t batch,
const dim_t beam,
const dim_t start,
const dim_t end) {
const auto* ids = history.index<int32_t>({batch, beam, 0});
return std::vector<size_t>(ids + start, ids + end);
}
static std::vector<std::vector<float>> build_attention(const StorageView& history,
const dim_t batch,
const dim_t beam,
const dim_t start,
const dim_t end) {
if (!history)
return {};
const auto source_length = history.dim(-1);
std::vector<std::vector<float>> attention;
attention.reserve(end - start);
for (dim_t t = start; t < end; ++t) {
const auto* vector = history.index<float>({batch, beam, t, 0});
attention.emplace_back(vector, vector + source_length);
}
return attention;
}
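// Coverage penalty: beta * sum over source positions of log(min(total attention, 1)).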
static float compute_coverage_penalty(const std::vector<std::vector<float>>& attention,
const float beta) {
float penalty = 0;
for (size_t column = 0; column < attention[0].size(); column++) {
float coverage = 0;
for (size_t row = 0; row < attention.size(); row++)
coverage += attention[row][column];
if (coverage > 0)
penalty += std::log(std::min(coverage, 1.f));
}
return beta * penalty;
}
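// Applies length normalization (score / length^length_penalty) and, if enabled, the coverage
// penalty computed from the attention history.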
static float finalize_hypothesis_score(float score,
const float length,
const float length_penalty,
const float coverage_penalty,
const std::vector<std::vector<float>>* attention) {
score /= std::pow(length, length_penalty);
if (coverage_penalty != 0) {
if (!attention)
throw std::runtime_error("The attention weights are required to apply the coverage penalty");
score += compute_coverage_penalty(*attention, coverage_penalty);
}
return score;
}
static inline void sort_hypotheses(DecodingResult& result,
size_t max_hypotheses,
bool keep_scores,
bool keep_attention) {
std::vector<size_t> idx(result.hypotheses.size());
std::iota(idx.begin(), idx.end(), 0);
std::sort(idx.begin(), idx.end(),
[&result](size_t i1, size_t i2) { return result.scores[i1] > result.scores[i2]; });
if (max_hypotheses < idx.size())
idx.resize(max_hypotheses);
result.hypotheses = index_vector(result.hypotheses, idx);
if (keep_scores)
result.scores = index_vector(result.scores, idx);
else
result.scores.clear();
if (keep_attention)
result.attention = index_vector(result.attention, idx);
else
result.attention.clear();
}
static inline void finalize_result(DecodingResult& result,
const size_t max_hypotheses,
const float length_penalty,
const float coverage_penalty,
const bool keep_scores,
const bool keep_attention) {
for (size_t i = 0; i < result.scores.size(); ++i) {
const auto* attention = result.attention.empty() ? nullptr : &result.attention[i];
result.scores[i] = finalize_hypothesis_score(result.scores[i],
result.hypotheses[i].size(),
length_penalty,
coverage_penalty,
attention);
}
sort_hypotheses(result, max_hypotheses, keep_scores, keep_attention);
}
BiasedDecoder::BiasedDecoder(const float prefix_bias_beta,
const std::vector<std::vector<size_t>>& prefix_ids)
: _prefix_bias_beta(prefix_bias_beta)
, _prefix_ids(prefix_ids)
{
}
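// For beams that have not diverged from the target prefix, biases the distribution towards
// the prefix token: log_probs = log(beta * one_hot(prefix[step]) + (1 - beta) * softmax(logits)).
// Diverged beams, and steps past the prefix, get the regular log softmax.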
void BiasedDecoder::decode(const dim_t cur_batch_size,
const size_t step,
const std::vector<dim_t>& batch_offset,
const std::vector<std::vector<bool>>& beams_diverged_from_prefix,
const StorageView& logits,
StorageView& log_probs) {
const dim_t num_beams = logits.dim(0);
const Device device = logits.device();
const DataType dtype = logits.dtype();
if (_spare_beam.dtype() != dtype || _spare_beam.device() != device) {
_spare_beam = StorageView(device, dtype);
}
std::vector<StorageView> logit_beam_view_storage(num_beams, StorageView(device, dtype));
std::vector<StorageView*> logit_beam_views(num_beams);
std::vector<StorageView> log_prob_beam_view_storage(num_beams, StorageView(device, dtype));
std::vector<StorageView*> log_prob_beam_views(num_beams);
for (dim_t i = 0; i < num_beams; ++i) {
logit_beam_views[i] = &(logit_beam_view_storage[i]);
log_prob_beam_views[i] = &(log_prob_beam_view_storage[i]);
}
ops::Split(0, true)(logits, logit_beam_views);
log_probs.resize_as(logits);
log_probs.reshape(logits.shape());
ops::Split(0, true)(log_probs, log_prob_beam_views);
StorageView scalar_discount(1 - _prefix_bias_beta, Device::CPU);
assert (num_beams % cur_batch_size == 0);
const dim_t cur_beam_size = num_beams / cur_batch_size;
for (dim_t b = 0; b < num_beams; ++b) {
StorageView &logit_beam = *(logit_beam_views[b]);
StorageView &log_prob_beam = *(log_prob_beam_views[b]);
const dim_t index_batch = b / cur_beam_size;
const dim_t index_beam = b % cur_beam_size;
const auto& prefix = _prefix_ids[batch_offset[index_batch]];
if (static_cast<size_t>(step) < prefix.size()
&& !beams_diverged_from_prefix[index_batch][index_beam]) {
ops::SoftMax()(logit_beam, log_prob_beam);
ops::Mul()(log_prob_beam,
scalar_discount.to(log_prob_beam.dtype()),
_spare_beam);
const size_t biased_word_id = prefix[step];
StorageView spare_scalar_view;
TYPE_DISPATCH(
_spare_beam.dtype(),
spare_scalar_view = StorageView({1}, _spare_beam.data<T>() + biased_word_id, device));
const StorageView spare_scalar_copy(spare_scalar_view);
StorageView beta_scalar;
TYPE_DISPATCH(
_spare_beam.dtype(),
beta_scalar = StorageView(static_cast<T>(_prefix_bias_beta), Device::CPU));
ops::Add()(spare_scalar_copy, beta_scalar, spare_scalar_view);
ops::Log()(_spare_beam, log_prob_beam);
} else {
ops::LogSoftMax()(logit_beam, log_prob_beam);
}
}
}
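// Marks a beam as diverged from the prefix once the sampled token differs from the prefix
// token at this step, once the prefix is exhausted, or if it had already diverged.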
static inline std::vector<std::vector<bool>>
get_beams_divergence_from_prefix(const std::vector<std::vector<bool>>& beams_diverged_from_prefix,
const size_t step,
const StorageView& sampled_ids,
const std::vector<std::vector<size_t>>& prefix_ids,
const std::vector<dim_t>& batch_offset) {
auto updated = beams_diverged_from_prefix;
for (dim_t i = 0; i < dim_t(updated.size()); ++i) {
for (dim_t k = 0; k < dim_t(updated[i].size()); ++k) {
const size_t word_id = sampled_ids.at<int32_t>({i, k});
const auto& prefix = prefix_ids[batch_offset[i]];
updated[i][k] = (step >= prefix.size()
|| beams_diverged_from_prefix[i][k]
|| word_id != prefix[step]);
}
}
return updated;
}
static inline bool
all_beams_diverged_from_prefix(const std::vector<std::vector<bool>>& beams_diverged_from_prefix) {
for (const auto& batch : beams_diverged_from_prefix) {
for (const bool beam_diverged : batch) {
if (!beam_diverged)
return false;
}
}
return true;
}
static inline size_t get_max_candidates(const dim_t beam_size, const float patience) {
return std::round(float(beam_size) * patience);
}
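// When a prefix is forced but not returned, max_length only counts tokens generated after
// the prefix, so the step budget is extended by the longest prefix in the batch.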
static dim_t get_max_step(const dim_t max_length,
const bool return_prefix,
const std::vector<std::vector<size_t>>* prefix_ids) {
dim_t max_step = 0;
if (prefix_ids && !return_prefix) {
for (const auto& ids : *prefix_ids) {
const dim_t prefix_length = ids.size();
max_step = std::max(max_step, prefix_length + max_length);
}
} else {
max_step = max_length;
}
return max_step;
}
static inline bool is_last_step(const dim_t step,
const dim_t max_length,
const dim_t prefix_length,
const bool return_prefix) {
return step + 1 == max_length + (return_prefix ? 0 : prefix_length);
}
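// Disables the end tokens while the decoded length (excluding a forced prefix, when it is
// not returned) is below min_length.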
static void apply_min_length(const dim_t step,
const dim_t min_length,
const std::vector<size_t>& end_ids,
DisableTokens& disable_tokens,
const std::vector<dim_t>& batch_offset,
const bool return_prefix,
const std::vector<std::vector<size_t>>* prefix_ids) {
if (prefix_ids && !return_prefix) {
const size_t batch_size = batch_offset.size();
for (size_t i = 0; i < batch_size; ++i) {
const dim_t batch_id = batch_offset[i];
const dim_t prefix_length = prefix_ids->at(batch_id).size();
if (step < prefix_length + min_length) {
for (const size_t end_id : end_ids)
disable_tokens.add(i, end_id);
}
}
} else if (step < min_length) {
for (const size_t end_id : end_ids)
disable_tokens.add(end_id);
}
}
BeamSearch::BeamSearch(const dim_t beam_size,
const float length_penalty,
const float coverage_penalty,
const float prefix_bias_beta,
const float patience)
: _beam_size(beam_size)
, _length_penalty(length_penalty)
, _coverage_penalty(coverage_penalty)
, _prefix_bias_beta(prefix_bias_beta)
, _max_candidates(get_max_candidates(beam_size, patience))
{
}
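// Beam search loop: at each step the decoder is run on all live beams, the log probabilities
// are (optionally) biased towards a prefix, the cumulative beam scores are added, and
// 2 * beam_size candidates per batch entry are sampled. Finished hypotheses are collected
// when a beam emits an end token or the step budget is reached; finished batch entries are
// removed and the decoder state is reordered to follow the surviving beams.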
std::vector<DecodingResult>
BeamSearch::search(layers::Decoder& decoder,
layers::DecoderState& state,
const Sampler& sampler,
const std::vector<size_t>& start_ids,
const std::vector<size_t>& end_ids,
const dim_t start_step,
const dim_t max_length,
const dim_t min_length,
const bool return_scores,
const bool return_attention,
const bool return_prefix,
const size_t num_hypotheses,
const bool include_eos_in_hypotheses,
const std::vector<std::shared_ptr<LogitsProcessor>>& logits_processors,
const std::vector<std::vector<size_t>>* prefix_ids) const {
PROFILE("beam_search");
const Device device = decoder.device();
const DataType dtype = decoder.output_type();
const dim_t vocabulary_size = decoder.output_size();
const dim_t batch_size = start_ids.size();
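// Sample twice as many candidates as beams so that candidates ending with an end token can
// be replaced by still-active ones.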
const dim_t num_candidates = _beam_size * 2;
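// On CPU, try to run the first step on the original batch and only expand the state to
// beam_size when gathering the step-0 candidates, since all beams of a batch entry start
// from the same state and input.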
const bool expand_after_first_step = (device == Device::CPU
&& num_candidates <= vocabulary_size);
const bool allow_early_exit = (_length_penalty == 0 && _coverage_penalty == 0);
StorageView topk_ids({batch_size}, DataType::INT32);
StorageView topk_scores(dtype);
std::vector<bool> top_beam_finished(batch_size, false);
std::vector<dim_t> batch_offset(batch_size);
std::vector<DecodingResult> results(batch_size);
for (dim_t i = 0; i < batch_size; ++i) {
batch_offset[i] = i;
topk_ids.at<int32_t>(i) = start_ids[i];
}
if (!expand_after_first_step) {
decoder.replicate_state(state, _beam_size);
repeat_batch(topk_ids, _beam_size);
TYPE_DISPATCH(dtype, initialize_beam_scores<T>(topk_scores, batch_size, _beam_size));
}
std::unique_ptr<BiasedDecoder> biased_decoder;
std::vector<std::vector<bool>> beams_diverged_from_prefix;
bool bias_towards_prefix = prefix_ids && _prefix_bias_beta > 0;
if (bias_towards_prefix) {
biased_decoder = std::make_unique<BiasedDecoder>(_prefix_bias_beta, *prefix_ids);
beams_diverged_from_prefix.resize(batch_size, std::vector<bool>(_beam_size, false));
}
const bool use_hard_prefix = prefix_ids && !bias_towards_prefix;
StorageView logits(dtype, device);
StorageView alive_seq(topk_ids.dtype());
StorageView alive_attention;
const dim_t max_step = get_max_step(max_length,
return_prefix,
use_hard_prefix ? prefix_ids : nullptr);
for (dim_t step = 0; step < max_step; ++step) {
const bool is_expanded = (!expand_after_first_step || step > 0);
StorageView attention_step(dtype, device);
convert_to_original_word_ids(decoder, topk_ids);
decoder(start_step + step,
topk_ids.to(device),
state,
&logits,
(return_attention || _coverage_penalty != 0) ? &attention_step : nullptr);
const dim_t cur_batch_size = is_expanded ? logits.dim(0) / _beam_size : logits.dim(0);
DisableTokens disable_tokens(logits);
apply_min_length(step,
min_length,
end_ids,
disable_tokens,
batch_offset,
return_prefix,
prefix_ids);
if (!logits_processors.empty()) {
if (alive_seq)
merge_batch_beam(alive_seq);
for (const auto& logits_processor : logits_processors)
logits_processor->apply(step, logits, disable_tokens, alive_seq, batch_offset, prefix_ids);
if (alive_seq)
split_batch_beam(alive_seq, _beam_size);
}
disable_tokens.apply();
StorageView log_probs(dtype, device);
if (bias_towards_prefix) {
biased_decoder->decode(cur_batch_size,
step,
batch_offset,
beams_diverged_from_prefix,
logits,
log_probs);
} else {
ops::LogSoftMax()(logits);
log_probs.shallow_copy(logits);
}
if (topk_scores) {
DEVICE_AND_TYPE_DISPATCH(log_probs.device(), log_probs.dtype(),
primitives<D>::add_depth_broadcast(topk_scores.to(device).data<T>(),
log_probs.data<T>(),
topk_scores.size(),
log_probs.size()));
}
log_probs.reshape({cur_batch_size, -1});
sampler(log_probs, topk_ids, topk_scores, num_candidates);
StorageView gather_indices = unflatten_ids(topk_ids, _beam_size, vocabulary_size, is_expanded);
if (prefix_ids) {
if (use_hard_prefix) {
update_sample_with_prefix(step,
topk_ids,
topk_scores,
*prefix_ids,
end_ids,
batch_offset,
_beam_size,
&gather_indices,
is_expanded);
} else if (bias_towards_prefix) {
beams_diverged_from_prefix = get_beams_divergence_from_prefix(beams_diverged_from_prefix,
step,
topk_ids,
*prefix_ids,
batch_offset);
}
}
append_step_output(alive_seq, topk_ids, &gather_indices);
if (attention_step) {
if (!is_expanded)
repeat_batch(attention_step, _beam_size);
split_batch_beam(attention_step, _beam_size);
append_step_output(alive_attention, attention_step.to_float32().to(Device::CPU));
gather_beam_flat(alive_attention, gather_indices, num_candidates);
}
std::vector<int32_t> non_finished_index;
non_finished_index.reserve(cur_batch_size);
StorageView active_beams({cur_batch_size * _beam_size}, DataType::INT32);
for (dim_t i = 0; i < cur_batch_size; ++i) {
const dim_t batch_id = batch_offset[i];
const dim_t prefix_length = use_hard_prefix ? prefix_ids->at(batch_id).size() : 0;
const bool is_last_step_for_batch = is_last_step(step,
max_length,
prefix_length,
return_prefix);
auto& result = results[batch_id];
dim_t secondary_candidates_offset = _beam_size;
for (dim_t k = 0; k < _beam_size; ++k) {
const size_t last_id = topk_ids.at<int32_t>({i, k});
dim_t next_beam_id = k;
if ((is_eos(last_id, end_ids) && step >= prefix_length) || is_last_step_for_batch) {
if (k == 0)
top_beam_finished[i] = true;
const bool ignore_last_token = is_eos(last_id, end_ids) && !include_eos_in_hypotheses;
const dim_t start = return_prefix ? 0 : prefix_length;
const dim_t end = ignore_last_token ? step : step + 1;
result.scores.emplace_back(topk_scores.scalar_at<float>({i, k}));
result.hypotheses.emplace_back(build_hypothesis(alive_seq, i, k, start, end));
if (alive_attention)
result.attention.emplace_back(build_attention(alive_attention, i, k, start, end));
for (dim_t j = secondary_candidates_offset; j < num_candidates; ++j) {
const auto candidate = topk_ids.at<int32_t>({i, j});
if (!is_eos(candidate, end_ids)) {
next_beam_id = j;
secondary_candidates_offset = j + 1;
break;
}
}
}
active_beams.at<int32_t>(i * _beam_size + k) = i * num_candidates + next_beam_id;
}
bool is_finished = false;
if (is_last_step_for_batch)
is_finished = true;
else if (allow_early_exit)
is_finished = top_beam_finished[i] && result.hypotheses.size() >= num_hypotheses;
else
is_finished = result.hypotheses.size() >= _max_candidates;
if (is_finished) {
finalize_result(result,
num_hypotheses,
_length_penalty,
_coverage_penalty,
return_scores,
return_attention);
} else {
non_finished_index.emplace_back(i);
}
}
const dim_t next_batch_size = non_finished_index.size();
if (next_batch_size == 0) {
if (!is_expanded) {
decoder.replicate_state(state, _beam_size);
}
break;
}
gather(gather_indices, active_beams);
gather_beam_flat(topk_ids, active_beams, _beam_size);
gather_beam_flat(topk_scores, active_beams, _beam_size);
gather_beam_flat(alive_seq, active_beams, _beam_size);
if (alive_attention)
gather_beam_flat(alive_attention, active_beams, _beam_size);
std::unique_ptr<StorageView> keep_batches;
if (next_batch_size != cur_batch_size) {
batch_offset = index_vector(batch_offset, non_finished_index);
top_beam_finished = index_vector(top_beam_finished, non_finished_index);
if (bias_towards_prefix)
beams_diverged_from_prefix = index_vector(beams_diverged_from_prefix, non_finished_index);
keep_batches = std::make_unique<StorageView>(Shape{next_batch_size}, non_finished_index);
gather(topk_ids, *keep_batches);
gather(topk_scores, *keep_batches);
gather(alive_seq, *keep_batches);
if (alive_attention)
gather(alive_attention, *keep_batches);
if (keep_batches->device() != device)
*keep_batches = keep_batches->to(device);
}
if (gather_indices.device() != device)
gather_indices = gather_indices.to(device);
decoder.update_state(state, gather_indices, _beam_size, keep_batches.get());
topk_ids.reshape({next_batch_size * _beam_size});
topk_scores.reshape({next_batch_size * _beam_size});
if (bias_towards_prefix)
bias_towards_prefix = !all_beams_diverged_from_prefix(beams_diverged_from_prefix);
}
return results;
}
GreedySearch::GreedySearch(const float length_penalty,
const float coverage_penalty,
std::function<bool(DecodingStepResult)> callback)
: _length_penalty(length_penalty)
, _coverage_penalty(coverage_penalty)
, _callback(std::move(callback))
{
}
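// Greedy (or sampling-based) search: a single candidate is sampled per batch entry at each
// step, finished entries are removed from the batch, and the per-step callback (if any) can
// stop a sequence early.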
std::vector<DecodingResult>
GreedySearch::search(layers::Decoder& decoder,
layers::DecoderState& state,
const Sampler& sampler,
const std::vector<size_t>& start_ids,
const std::vector<size_t>& end_ids,
const dim_t start_step,
const dim_t max_length,
const dim_t min_length,
const bool return_scores,
const bool return_attention,
const bool return_prefix,
const size_t num_hypotheses,
const bool include_eos_in_hypotheses,
const std::vector<std::shared_ptr<LogitsProcessor>>& logits_processors,
const std::vector<std::vector<size_t>>* prefix_ids) const {
const dim_t batch_size = start_ids.size();
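// Multiple hypotheses are produced by repeating each batch entry num_hypotheses times and
// running a single-hypothesis search on the expanded batch.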
if (num_hypotheses > 1) {
for (auto& [name, value] : state) {
if (value)
repeat_batch(value, num_hypotheses);
}
std::vector<size_t> repeat_start_ids = repeat_vector(start_ids, num_hypotheses);
std::vector<std::vector<size_t>> repeat_prefix_ids;
if (prefix_ids)
repeat_prefix_ids = repeat_vector(*prefix_ids, num_hypotheses);
std::unique_ptr<GreedySearch> greedy;
if (_callback) {
auto hypothesis_callback = [this, num_hypotheses](DecodingStepResult result) {
result.hypothesis_id = result.batch_id % num_hypotheses;
result.batch_id /= num_hypotheses;
return _callback(std::move(result));
};
greedy = std::make_unique<GreedySearch>(_length_penalty,
_coverage_penalty,
std::move(hypothesis_callback));
}
std::vector<DecodingResult> results = (greedy ? greedy.get() : this)->search(
decoder,
state,
sampler,
repeat_start_ids,
end_ids,
start_step,
max_length,
min_length,
true,
return_attention,
return_prefix,
1,
include_eos_in_hypotheses,
logits_processors,
prefix_ids ? &repeat_prefix_ids : nullptr);
std::vector<DecodingResult> final_results(batch_size);
for (size_t i = 0; i < results.size(); ++i) {
auto& result = results[i];
auto& final_result = final_results[i / num_hypotheses];
final_result.hypotheses.emplace_back(std::move(result.hypotheses[0]));
final_result.scores.emplace_back(result.scores[0]);
if (return_attention)
final_result.attention.emplace_back(std::move(result.attention[0]));
}
for (auto& result : final_results)
sort_hypotheses(result, num_hypotheses, return_scores, return_attention);
return final_results;
}
PROFILE("greedy_search");
const Device device = decoder.device();
const DataType dtype = decoder.output_type();
const bool gather_attention = (return_attention || (return_scores && _coverage_penalty != 0));
StorageView sample_from({batch_size}, DataType::INT32);
StorageView logits(dtype, device);
std::vector<dim_t> batch_offset(batch_size);
std::vector<DecodingResult> results(batch_size);
for (dim_t i = 0; i < batch_size; ++i) {
batch_offset[i] = i;
sample_from.at<int32_t>(i) = start_ids[i];
results[i].hypotheses.resize(1);
if (return_scores)
results[i].scores.resize(1, 0.f);
if (return_attention)
results[i].attention.resize(1);
}
StorageView best_ids(DataType::INT32);
StorageView best_probs(dtype);
StorageView alive_seq(DataType::INT32);
StorageView attention_step;
StorageView attention_step_device(dtype, device);
const dim_t max_step = get_max_step(max_length, return_prefix, prefix_ids);
for (dim_t step = 0; step < max_step; ++step) {
convert_to_original_word_ids(decoder, sample_from);
decoder(start_step + step,
sample_from.to(device),
state,
&logits,
gather_attention ? &attention_step_device : nullptr);
DisableTokens disable_tokens(logits);
apply_min_length(step,
min_length,
end_ids,
disable_tokens,
batch_offset,
return_prefix,
prefix_ids);
for (const auto& logits_processor : logits_processors)
logits_processor->apply(step, logits, disable_tokens, alive_seq, batch_offset, prefix_ids);
disable_tokens.apply();
StorageView log_probs(dtype, device);
if (return_scores)
ops::LogSoftMax()(logits);
log_probs.shallow_copy(logits);
sampler(log_probs, best_ids, best_probs);
if (prefix_ids)
update_sample_with_prefix(step, best_ids, best_probs, *prefix_ids, end_ids, batch_offset);
if (attention_step_device)
attention_step.copy_from(attention_step_device.to_float32());
if (!logits_processors.empty()) {
if (alive_seq) {
const StorageView cur_alive_seq = std::move(alive_seq);
ops::Concat(-1)({&cur_alive_seq, &best_ids}, alive_seq);
} else {
alive_seq = best_ids;
}
}
const dim_t cur_batch_size = log_probs.dim(0);
std::vector<int32_t> non_finished_index;
non_finished_index.reserve(cur_batch_size);
for (dim_t i = 0; i < cur_batch_size; ++i) {
const size_t word_id = best_ids.at<int32_t>(i);
const size_t batch_id = batch_offset[i];
const dim_t prefix_length = prefix_ids ? prefix_ids->at(batch_id).size() : 0;
const float score = best_probs.scalar_at<float>({i, 0});
if ((!is_eos(word_id, end_ids) || include_eos_in_hypotheses)
&& (return_prefix || step >= prefix_length)) {
results[batch_id].hypotheses[0].push_back(word_id);
if (attention_step) {
const auto* attn = attention_step.index<float>({i, 0});
results[batch_id].attention[0].emplace_back(attn, attn + attention_step.dim(-1));
}
}
if (return_scores)
results[batch_id].scores[0] += score;
bool is_finished = ((is_eos(word_id, end_ids) && step >= prefix_length)
|| (is_last_step(step, max_length, prefix_length, return_prefix)));
if (_callback && (return_prefix || step >= prefix_length)) {
DecodingStepResult step_result;
step_result.step = step;
step_result.batch_id = batch_id;
step_result.token_id = word_id;
step_result.hypothesis_id = 0;
step_result.is_last = is_finished;
if (return_scores)
step_result.log_prob = score;
if (_callback(std::move(step_result))) {
is_finished = true;
}
}
if (is_finished) {
finalize_result(results[batch_id],
1,
_length_penalty,
_coverage_penalty,
return_scores,
return_attention);
} else {
non_finished_index.emplace_back(i);
sample_from.at<int32_t>(i) = word_id;
}
}
const dim_t count_alive = non_finished_index.size();
if (count_alive == 0)
break;
if (count_alive != cur_batch_size) {
batch_offset = index_vector(batch_offset, non_finished_index);
StorageView alive({count_alive}, non_finished_index);
if (alive_seq)
gather(alive_seq, alive);
gather(sample_from, alive);
decoder.update_state(state, alive.to(device));
}
}
return results;
}
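// Extracts from the decoder state the slice corresponding to a single batch entry.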
static layers::DecoderState get_batch_state(const layers::DecoderState& state,
const int32_t batch_id) {
const Device device = state.begin()->second.device();
const ops::Gather gather_op;
StorageView indices(batch_id, device);
indices.reshape({1});
layers::DecoderState batch_state;
batch_state.reserve(state.size());
for (const auto& pair : state) {
const auto& name = pair.first;
const auto& value = pair.second;
StorageView batch_value(value.dtype(), device);
if (value)
gather_op(value, indices, batch_value);
batch_state.emplace(name, std::move(batch_value));
}
return batch_state;
}
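// Splits each start sequence into its first token (the decoder start token) and the remaining
// tokens (the target prefix). The prefix list is cleared when no input actually has a prefix.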
static std::pair<std::vector<size_t>, std::vector<std::vector<size_t>>>
split_start_tokens(const std::vector<std::vector<size_t>>& start_tokens) {
std::vector<size_t> start_ids;
std::vector<std::vector<size_t>> prefix_ids;
start_ids.reserve(start_tokens.size());
prefix_ids.reserve(start_tokens.size());
bool only_start_token = true;
for (const auto& tokens : start_tokens) {
if (tokens.empty())
throw std::invalid_argument("One input has no decoder start token");
if (tokens.size() > 1)
only_start_token = false;
start_ids.emplace_back(tokens.front());
prefix_ids.emplace_back(tokens.begin() + 1, tokens.end());
}
if (only_start_token)
prefix_ids.clear();
return std::make_pair(std::move(start_ids), std::move(prefix_ids));
}
static void validate_decoding_options(const DecodingOptions& options, const Device device) {
if (options.beam_size == 0)
throw std::invalid_argument("The beam size must be > 0");
if (options.patience <= 0)
throw std::invalid_argument("The patience factor must be > 0");
if (options.num_hypotheses == 0)
throw std::invalid_argument("The number of hypotheses must be > 0");
if (options.num_hypotheses > get_max_candidates(options.beam_size, options.patience)
&& !options.return_alternatives
&& !(options.beam_size == 1 && options.sampling_topk != 1))
throw std::invalid_argument("The number of hypotheses cannot be greater than "
"beam_size * patience");
if (options.min_length > options.max_length)
throw std::invalid_argument("The minimum decoding length is greater than "
"the maximum decoding length");
if (options.max_length == 0)
throw std::invalid_argument("The maximum decoding length must be > 0");
if (options.repetition_penalty <= 0)
throw std::invalid_argument("The repetition penalty must be > 0");
if (options.prefix_bias_beta >= 1)
throw std::invalid_argument("The beta value in biased decoding must be < 1");
if (options.prefix_bias_beta > 0 && options.return_alternatives)
throw std::invalid_argument("Biased decoding is not compatible with the return_alternatives "
"mode");
if (options.return_alternatives
&& (options.min_alternative_expansion_prob < 0
|| options.min_alternative_expansion_prob > 1))
throw std::invalid_argument("The minimum alternative expansion probability must be "
"between 0 and 1");
if (options.callback && (options.beam_size != 1 || options.prefix_bias_beta > 0))
throw std::invalid_argument("The callback function is not compatible with "
"beam_size > 1 or prefix_bias_beta > 0");
if (options.sampling_topp <= 0 || options.sampling_topp > 1)
throw std::invalid_argument("The sampling_topp parameter must be between 0 and 1");
if (options.sampling_topp < 1
&& options.sampling_topk > static_cast<size_t>(ops::TopPMask::max_num_classes(device)))
throw std::invalid_argument(
"The sampling_topp parameter currently requires sampling_topk <= "
+ std::to_string(ops::TopPMask::max_num_classes(device))
+ " when running on a " + device_to_str(device) + " device");
}
static std::unique_ptr<const Sampler>
make_sampler(const DecodingOptions& options) {
if (options.sampling_topk == 1 || options.sampling_temperature == 0.0)
return std::make_unique<BestSampler>();
else
return std::make_unique<RandomSampler>(options.sampling_topk,
options.sampling_topp,
options.sampling_temperature);
}
static std::unique_ptr<const SearchStrategy>
make_search_strategy(const DecodingOptions& options) {
if (options.beam_size == 1 && options.prefix_bias_beta == 0)
return std::make_unique<GreedySearch>(options.length_penalty,
options.coverage_penalty,
options.callback);
else
return std::make_unique<BeamSearch>(options.beam_size,
options.length_penalty,
options.coverage_penalty,
options.prefix_bias_beta,
options.patience);
}
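// Builds the logits processors in application order: user processors flagged to run first,
// then the built-in processors derived from the options, then the remaining user processors.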
static std::vector<std::shared_ptr<LogitsProcessor>>
make_logits_processors(const DecodingOptions& options) {
std::vector<std::shared_ptr<LogitsProcessor>> processors;
for (const auto& processor : options.logits_processors) {
if (processor->apply_first())
processors.emplace_back(processor);
}
if (options.repetition_penalty != 1)
processors.emplace_back(std::make_shared<RepetitionPenalty>(options.repetition_penalty));
if (options.no_repeat_ngram_size > 0)
processors.emplace_back(std::make_shared<NoRepeatNgram>(options.no_repeat_ngram_size));
if (!options.disable_ids.empty())
processors.emplace_back(std::make_shared<SuppressTokens>(options.disable_ids));
if (!options.disable_ids_begin.empty())
processors.emplace_back(std::make_shared<SuppressTokensBegin>(options.disable_ids_begin));
if (!options.disable_sequences.empty())
processors.emplace_back(std::make_shared<SuppressSequences>(options.disable_sequences));
for (const auto& processor : options.logits_processors) {
if (!processor->apply_first())
processors.emplace_back(processor);
}
return processors;
}
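// return_alternatives mode: force-decode the common prefix, expand the next token into up to
// num_hypotheses alternatives with a one-step beam search, then let each alternative continue
// independently with the configured search strategy.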
static DecodingResult
decode_alternatives(layers::Decoder& decoder,
layers::DecoderState& state,
std::vector<size_t> start_tokens,
const std::vector<size_t>& end_ids,
const DecodingOptions& options) {
DecodingResult result;
result.hypotheses.resize(options.num_hypotheses);
if (options.return_scores)
result.scores.resize(options.num_hypotheses, 0);
if (options.return_attention)
result.attention.resize(options.num_hypotheses);
if (start_tokens.empty())
throw std::invalid_argument("One input has no decoder start token");
if (start_tokens.size() > options.max_length + 1)
start_tokens.resize(options.max_length + 1);
const dim_t min_length = options.min_length;
const dim_t max_length = options.max_length;
const dim_t prefix_length = start_tokens.size() - 1;
dim_t start_step = options.start_step;
if (prefix_length > 0) {
const Device device = decoder.device();
StorageView attention(decoder.output_type(), device);
StorageView input_ids({1, prefix_length},
std::vector<int32_t>(start_tokens.begin(),
start_tokens.begin() + prefix_length),
device);
convert_to_original_word_ids(decoder, input_ids);
decoder(start_step,
input_ids,
state,
nullptr,
options.return_attention ? &attention : nullptr);
for (size_t i = 0; i < options.num_hypotheses; ++i) {
result.hypotheses[i] = std::vector<size_t>(start_tokens.begin() + 1, start_tokens.end());
if (options.return_attention) {
if (attention.device() != Device::CPU)
attention = attention.to_float32().to(Device::CPU);
for (dim_t t = 0; t < prefix_length; ++t) {
const float* vector = attention.index<float>({0, t, 0});
result.attention[i].emplace_back(vector, vector + attention.dim(-1));
}
}
}
if (prefix_length == max_length)
return result;
start_step += prefix_length;
}
std::vector<size_t> start_ids{start_tokens.back()};
const auto logits_processors = make_logits_processors(options);
BeamSearch beam(options.num_hypotheses);
DecodingResult expansion_result = beam.search(decoder,
state,
BestSampler(),
start_ids,
end_ids,
start_step,
1,
1,
true,
options.return_attention,
options.return_prefix,
options.num_hypotheses,
options.include_eos_in_hypotheses,
logits_processors)[0];
start_ids.clear();
for (size_t i = 0; i < options.num_hypotheses; ++i) {
const float prob = std::exp(expansion_result.scores[i]);
if (prob < options.min_alternative_expansion_prob)
break;
result.hypotheses[i].emplace_back(expansion_result.hypotheses[i].back());
if (options.return_attention)
result.attention[i].emplace_back(std::move(expansion_result.attention[i].back()));
if (options.return_scores)
result.scores[i] = expansion_result.scores[i];
start_ids.push_back(result.hypotheses[i].back());
}
const size_t num_alternatives = start_ids.size();
for (auto& [name, value] : state) {
if (decoder.replicate_state(name)) {
if (num_alternatives < options.num_hypotheses)
value.resize(0, num_alternatives);
} else {
repeat_batch(value, num_alternatives);
}
}
if (num_alternatives < options.num_hypotheses) {
result.hypotheses.resize(num_alternatives);
if (options.return_scores)
result.scores.resize(num_alternatives);
if (options.return_attention)
result.attention.resize(num_alternatives);
}
start_step += 1;
if (start_step == max_length)
return result;
const auto search_strategy = make_search_strategy(options);
const auto sampler = make_sampler(options);
auto suffix_results = search_strategy->search(decoder,
state,
*sampler,
start_ids,
end_ids,
start_step,
std::max(max_length - start_step, dim_t(0)),
std::max(min_length - start_step, dim_t(0)),
options.return_scores,
options.return_attention,
options.return_prefix,
1,
options.include_eos_in_hypotheses,
logits_processors);
for (size_t i = 0; i < suffix_results.size(); ++i) {
auto& suffix = suffix_results[i];
if (options.return_scores) {
result.scores[i] += suffix.scores[0];
}
if (options.return_attention)
result.attention[i].insert(result.attention[i].end(),
std::make_move_iterator(suffix.attention[0].begin()),
std::make_move_iterator(suffix.attention[0].end()));
result.hypotheses[i].insert(result.hypotheses[i].end(),
std::make_move_iterator(suffix.hypotheses[0].begin()),
std::make_move_iterator(suffix.hypotheses[0].end()));
}
return result;
}
static std::vector<size_t> map_to_output_word_ids(const layers::Decoder& decoder,
const std::vector<size_t>& ids) {
std::vector<size_t> new_ids;
new_ids.reserve(ids.size());
for (const size_t id : ids) {
if (decoder.is_in_output(id))
new_ids.push_back(decoder.to_output_word_id(id));
}
return new_ids;
}
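// Decoding entry point: validates the options, maps ids to the reduced output vocabulary when
// the output layer was updated, runs either the return_alternatives flow or the selected
// search strategy, and maps the hypotheses back to original word ids.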
std::vector<DecodingResult>
decode(layers::Decoder& decoder,
layers::DecoderState& state,
std::vector<std::vector<size_t>> start_tokens,
std::vector<size_t> end_ids,
DecodingOptions options) {
validate_decoding_options(options, decoder.device());
const size_t batch_size = start_tokens.size();
if (batch_size == 0)
throw std::invalid_argument("No decoder start tokens are set");
std::vector<DecodingResult> results;
if (decoder.output_layer_is_updated()) {
end_ids = map_to_output_word_ids(decoder, end_ids);
for (auto& ids : start_tokens)
ids = map_to_output_word_ids(decoder, ids);
for (auto& ids : options.disable_sequences)
ids = map_to_output_word_ids(decoder, ids);
options.disable_ids = map_to_output_word_ids(decoder, options.disable_ids);
options.disable_ids_begin = map_to_output_word_ids(decoder, options.disable_ids_begin);
}
if (options.return_alternatives) {
results.reserve(batch_size);
for (size_t i = 0; i < batch_size; ++i) {
layers::DecoderState batch_state = get_batch_state(state, i);
results.emplace_back(decode_alternatives(decoder,
batch_state,
start_tokens[i],
end_ids,
options));
}
} else {
std::vector<size_t> start_ids;
std::vector<std::vector<size_t>> prefix_ids;
std::tie(start_ids, prefix_ids) = split_start_tokens(start_tokens);
const auto search_strategy = make_search_strategy(options);
const auto sampler = make_sampler(options);
const auto logits_processors = make_logits_processors(options);
results = search_strategy->search(decoder,
state,
*sampler,
start_ids,
end_ids,
options.start_step,
options.max_length,
options.min_length,
options.return_scores,
options.return_attention,
options.return_prefix,
options.num_hypotheses,
options.include_eos_in_hypotheses,
logits_processors,
prefix_ids.empty() ? nullptr : &prefix_ids);
}
if (decoder.output_layer_is_updated()) {
for (auto& result : results) {
for (auto& hypothesis : result.hypotheses) {
for (auto& id : hypothesis)
id = decoder.to_original_word_id(id);
}
}
}
return results;
}
}