#define LLAMA_API_INTERNAL
#include "sampling.h"
#include "llama-vocab.h"
#include "common.h"
#include "reasoning-budget.cpp"
#include <limits>
#include <random>
#if defined(__GNUC__) && (defined(__x86_64__) || defined(__i386__))
#include <immintrin.h>
#endif
#include <nlohmann/json.hpp>
using json = nlohmann::ordered_json;
struct llama_sampler_adaptive_p * llama_clone_adaptive_p(const struct llama_sampler_adaptive_p * adapt_p_ctx);
void llama_free_adaptive_p(struct llama_sampler_adaptive_p * adapt_p_ctx);
// Create a sampling context for `model` from `params`.
//
// Steps:
//   1. build the grammar (llguidance when the grammar text starts with "%llguidance",
//      otherwise a GBNF grammar, lazy when params.grammar_lazy, with its triggers);
//   2. tokenize params.generation_prompt so the grammar and the reasoning budget can
//      be advanced over the prefilled tokens;
//   3. create the reasoning-budget tracker when start/end markers are configured;
//   4. seed the RNG and instantiate the stateful samplers (DRY, adaptive-p) that are
//      named in params.samplers_sequence.
//
// The returned object is heap-allocated; release it with common_sampler_free().
// If the grammar rejects a prefill token, the partially built context is freed and
// the original exception is rethrown unchanged.
struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) {
    const llama_vocab * vocab = llama_model_get_vocab(model);
    struct common_sampler * result = new common_sampler();
    result->params  = params;
    result->grammar = nullptr;
    result->rbudget = nullptr;

    struct llama_grammar * grmr = nullptr;
    const std::string & grammar_str = common_grammar_value(params.grammar);
    if (grammar_str.compare(0, 11, "%llguidance") == 0) {
#ifdef LLAMA_USE_LLGUIDANCE
        grmr = llama_sampler_init_llg(vocab, "lark", grammar_str.c_str());
        result->grammar = grmr;
#else
        GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
#endif
    } else {
        // Translate the configured triggers into regex patterns / trigger tokens for the lazy grammar.
        std::vector<std::string> trigger_patterns;
        std::vector<llama_token> trigger_tokens;
        for (const auto & trigger : params.grammar_triggers) {
            switch (trigger.type) {
                case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
                    {
                        // A literal word: escape it so regex metacharacters match themselves.
                        const auto & word = trigger.value;
                        trigger_patterns.push_back(regex_escape(word));
                        break;
                    }
                case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
                    {
                        trigger_patterns.push_back(trigger.value);
                        break;
                    }
                case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL:
                    {
                        // Anchor the pattern at both ends (unless already anchored); an empty pattern becomes "^$".
                        const auto & pattern = trigger.value;
                        std::string anchored = "^$";
                        if (!pattern.empty()) {
                            anchored = (pattern.front() != '^' ? "^" : "")
                                + pattern
                                + (pattern.back() != '$' ? "$" : "");
                        }
                        trigger_patterns.push_back(anchored);
                        break;
                    }
                case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN:
                    {
                        const auto token = trigger.token;
                        trigger_tokens.push_back(token);
                        break;
                    }
                default:
                    GGML_ASSERT(false && "unknown trigger type");
            }
        }

        // C views of the patterns; trigger_patterns outlives the grammar init call below.
        std::vector<const char *> trigger_patterns_c;
        trigger_patterns_c.reserve(trigger_patterns.size());
        for (const auto & regex : trigger_patterns) {
            trigger_patterns_c.push_back(regex.c_str());
        }

        if (!grammar_str.empty()) {
            grmr = params.grammar_lazy
                ? llama_sampler_init_grammar_lazy_patterns(vocab, grammar_str.c_str(), "root",
                        trigger_patterns_c.data(), trigger_patterns_c.size(),
                        trigger_tokens.data(), trigger_tokens.size())
                : llama_sampler_init_grammar(vocab, grammar_str.c_str(), "root");
            // May be nullptr if the grammar failed to parse; sampling then runs without a grammar.
            result->grammar = grmr;
        }
        result->grammar_str  = grammar_str;
        result->grammar_root = "root";
    }

    // Token history consumed by the repetition penalties. It is needed whether or not a
    // grammar exists, so size it unconditionally (a negative n_prev would wrap in resize()).
    if (params.n_prev > 0) {
        result->prev.resize(params.n_prev);
    }
    result->n_valid = 0;

    // Tokenize the generation prompt that the model has already been fed so the grammar
    // and reasoning budget start in the state that follows it.
    std::vector<llama_token> prefill_tokens;
    if (!params.generation_prompt.empty()) {
        GGML_ASSERT(vocab != nullptr);
        auto tokens = common_tokenize(vocab, params.generation_prompt, false, true);
        for (size_t i = 0; i < tokens.size(); i++) {
            const std::string piece = common_token_to_piece(vocab, tokens[i], true);
            // Drop a leading-space token the tokenizer introduced that the prompt itself does not start with.
            // Cast to unsigned char: std::isspace is undefined for negative char values (UTF-8 bytes).
            if (i == 0 && !piece.empty()
                    && std::isspace(static_cast<unsigned char>(piece[0]))
                    && !std::isspace(static_cast<unsigned char>(params.generation_prompt[0]))) {
                continue;
            }
            LOG_DBG("%s: prefill token: %d = %s\n", __func__, tokens[i], piece.c_str());
            prefill_tokens.push_back(tokens[i]);
        }
    }

    if (grmr && !params.grammar_lazy && common_grammar_needs_prefill(params.grammar)) {
        try {
            for (const auto & token : prefill_tokens) {
                llama_grammar_accept_impl(*grmr, vocab, nullptr, token);
                LOG_DBG("%s: grammar accepted prefill token (%d)\n", __func__, token);
            }
        }
        catch (const std::exception &) {
            LOG_ERR("%s: error initializing grammar sampler for grammar:\n%s\n\nGeneration prompt:\n'%s'\n", __func__,
                    grammar_str.c_str(), params.generation_prompt.c_str());
            // Do not leak the half-built context (it owns grmr), and rethrow the original
            // exception object: `throw e;` would slice it down to a std::exception copy.
            common_sampler_free(result);
            throw;
        }
    }

    if (!params.reasoning_budget_start.empty() && !params.reasoning_budget_end.empty() && (params.grammar_lazy || params.reasoning_budget_tokens >= 0)) {
        // A negative token budget means "track the markers but never force the end".
        result->rbudget = common_reasoning_budget_init(
            vocab,
            params.reasoning_budget_start,
            params.reasoning_budget_end,
            params.reasoning_budget_forced,
            params.reasoning_budget_tokens < 0 ? INT_MAX : params.reasoning_budget_tokens);
        for (const auto & token : prefill_tokens) {
            common_reasoning_budget_accept(result->rbudget, token);
            LOG_DBG("%s: reasoning-budget accepted prefill token (%d)\n", __func__, token);
        }
    }

    llama_sampling_set_rng_seed(result, params.seed);

    // Only DRY and adaptive-p carry state that must be created up front.
    for (const auto & cnstr : params.samplers_sequence) {
        switch (cnstr) {
            case llama_sampler_type::DRY:
                {
                    std::vector<const char *> c_breakers;
                    c_breakers.reserve(params.dry_sequence_breakers.size());
                    for (const auto & str : params.dry_sequence_breakers) {
                        c_breakers.push_back(str.c_str());
                    }
                    if (result->smpl) {
                        // A repeated DRY entry must not leak the previously created sampler.
                        llama_sampler_dry_free(result->smpl);
                    }
                    result->smpl = llama_sampler_init_dry(vocab, params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size());
                    break;
                }
            case llama_sampler_type::ADAPTIVE_P:
                {
                    if (params.adaptive_target >= 0.0f) {
                        GGML_ASSERT(vocab);
                        auto n_vocab = llama_vocab_n_tokens(vocab);
                        if (result->adapt_p_ctx) {
                            // A repeated ADAPTIVE_P entry must not leak the previous context.
                            llama_free_adaptive_p(result->adapt_p_ctx);
                        }
                        result->adapt_p_ctx = llama_init_adaptive_p(n_vocab, params.adaptive_target, params.adaptive_decay, params.adaptive_updt_w_cur, result->rng());
                    }
                    break;
                }
            default:
                break;
        }
    }
    result->elb_idx = 0;
    result->elb_search_pos = 0;
    return result;
}
// Destroy a sampling context and every sub-object it owns: the grammar, the DRY
// sampler, the adaptive-p context and the reasoning-budget tracker.
// A null pointer is accepted and ignored.
void common_sampler_free(struct common_sampler * ctx) {
    if (ctx == nullptr) {
        return;
    }
    if (ctx->grammar != nullptr)     { llama_grammar_free(ctx->grammar); }
    if (ctx->smpl != nullptr)        { llama_sampler_dry_free(ctx->smpl); }
    if (ctx->adapt_p_ctx != nullptr) { llama_free_adaptive_p(ctx->adapt_p_ctx); }
    if (ctx->rbudget != nullptr)     { common_reasoning_budget_free(ctx->rbudget); }
    delete ctx;
}
// Replace ctx->grammar with a freshly initialised grammar. The new grammar is built
// from the stored grammar text and root rule, and reuses the current grammar's
// vocabulary, lazy flag, trigger patterns and trigger tokens. The old grammar is
// then freed. Does nothing when there is no grammar.
static void llama_grammar_reset(common_sampler * ctx) {
    auto * old_grammar = ctx->grammar;
    if (old_grammar == nullptr) {
        return;
    }

    // C views of the trigger patterns; they stay valid until old_grammar is freed below.
    std::vector<const char *> patterns;
    patterns.reserve(old_grammar->trigger_patterns.size());
    for (const auto & tp : old_grammar->trigger_patterns) {
        patterns.push_back(tp.pattern.c_str());
    }

    auto * fresh = llama_grammar_init_impl(
        old_grammar->vocab,
        ctx->grammar_str.c_str(),
        ctx->grammar_root.c_str(),
        old_grammar->lazy,
        patterns.data(), patterns.size(),
        old_grammar->trigger_tokens.data(), old_grammar->trigger_tokens.size());

    llama_grammar_free_impl(old_grammar);
    ctx->grammar = fresh;
}
// Forget the accepted-token history and reset the DRY sampler's state.
void common_sampler_reset(common_sampler * ctx) {
    auto & history = ctx->prev;
    history.clear();
    llama_sampler_dry_reset(ctx->smpl);
}
// Forward a review notification (number of unsent tokens and rewind status) to the
// adaptive-p context. This is a no-op when adaptive-p is not in use.
void common_sampler_review(common_sampler * ctx, const size_t n_unsent, const bool rewind_status) {
    auto * adapt_p = ctx->adapt_p_ctx;
    if (adapt_p == nullptr) {
        return;
    }
    llama_review_adaptive_p(adapt_p, n_unsent, rewind_status);
}
// Seed the sampler's RNG. LLAMA_DEFAULT_SEED requests a non-deterministic seed
// taken from std::random_device.
void llama_sampling_set_rng_seed(struct common_sampler * ctx, uint32_t seed) {
    const uint32_t effective_seed = (seed == LLAMA_DEFAULT_SEED)
        ? static_cast<uint32_t>(std::random_device{}())
        : seed;
    ctx->rng.seed(effective_seed);
}
// Deep-copy the sampling state of `src` into `dst`.
//
// The following fields are copied by value: params, mirostat_mu, n_valid, rng,
// server_biases and the token history. The grammar, DRY sampler, adaptive-p context
// and reasoning budget that `dst` owns are freed first and replaced by independent
// copies of `src`'s, or left null when `src` has none. The grammar text and root are
// copied only when `src` has a grammar.
//
// Cloning a sampler onto itself is a no-op. Without that guard, the frees below would
// also destroy `src`'s objects, and the copy would then read freed memory.
void common_sampler_clone(common_sampler * src, common_sampler * dst) {
    if (src == dst) {
        return;
    }
    dst->params = src->params;
    dst->mirostat_mu = src->mirostat_mu;
    dst->n_valid = src->n_valid;
    dst->rng = src->rng;
    dst->server_biases = src->server_biases;
    if (dst->grammar) {
        llama_grammar_free(dst->grammar);
        dst->grammar = nullptr;
    }
    if (src->grammar) {
        dst->grammar_root = src->grammar_root;
        dst->grammar_str = src->grammar_str;
        dst->grammar = llama_grammar_copy(src->grammar);
    }
    dst->prev = src->prev;
    if (dst->smpl) {
        llama_sampler_dry_free(dst->smpl);
        dst->smpl = nullptr;
    }
    if (src->smpl) {
        dst->smpl = llama_sampler_dry_clone(src->smpl);
    }
    if (dst->adapt_p_ctx) {
        llama_free_adaptive_p(dst->adapt_p_ctx);
        dst->adapt_p_ctx = nullptr;
    }
    if (src->adapt_p_ctx) {
        dst->adapt_p_ctx = llama_clone_adaptive_p(src->adapt_p_ctx);
    }
    if (dst->rbudget) {
        common_reasoning_budget_free(dst->rbudget);
        dst->rbudget = nullptr;
    }
    if (src->rbudget) {
        dst->rbudget = common_reasoning_budget_clone(src->rbudget);
    }
}
// Return the most recently accepted token.
// The history must not be empty (it is emptied by common_sampler_reset, for example).
// Calling back() on an empty vector is undefined behaviour, so this case now fails
// loudly with an assertion instead.
llama_token llama_sampling_last(common_sampler * ctx) {
    GGML_ASSERT(!ctx->prev.empty() && "llama_sampling_last: token history is empty");
    return ctx->prev.back();
}
// Detokenize the last `n` tokens of the sampler's history (the whole history if it
// is shorter than n) and return them concatenated, oldest first.
std::string llama_sampling_prev_str(common_sampler * ctx_sampling, llama_context * ctx_main, int n) {
    const auto & history = ctx_sampling->prev;
    const int total = history.size();
    const int first = total - std::min(n, total);

    std::string text;
    for (int k = first; k < total; ++k) {
        text += common_token_to_piece(ctx_main, history[k]);
    }
    return text;
}
// Format the main sampling parameters as a multi-line, tab-indented summary.
// The output is truncated to 1023 characters.
std::string llama_sampling_print(const common_params_sampling & params) {
    char text[1024];
    snprintf(text, sizeof(text),
            "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
            "\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, typical_p = %.3f, temp = %.3f\n"
            "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f\n"
            "\txtc_probability = %.3f, xtc_threshold = %.3f, top_n_sigma = %.3f\n"
            "\tadaptive_target = %.2f, adaptive_decay = %.2f",
            // penalties
            params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present,
            // truncation samplers and temperature
            params.top_k, params.tfs_z, params.top_p, params.min_p, params.typical_p, params.temp,
            // mirostat: learning rate is eta, target entropy is tau
            params.mirostat, params.mirostat_eta, params.mirostat_tau,
            params.xtc_probability, params.xtc_threshold, params.top_n_sigma,
            params.adaptive_target, params.adaptive_decay);
    return text;
}
// Describe the sampler order as "CFG -> Penalties -> a -> b ...".
// When mirostat is enabled it replaces the configured chain. Sampler types without
// a printable name are skipped.
std::string llama_sampling_order_print(const common_params_sampling & params) {
    std::string chain = "CFG -> Penalties ";
    if (params.mirostat != 0) {
        chain += "-> mirostat ";
        return chain;
    }
    for (const auto type : params.samplers_sequence) {
        const std::string name = llama_sampling_type_to_str(type);
        if (name.empty()) {
            continue;
        }
        chain += "-> ";
        chain += name;
        chain += " ";
    }
    return chain;
}
// Canonical lowercase name of a sampler type, or "" for types without one.
std::string llama_sampling_type_to_str(llama_sampler_type sampler_type) {
    const char * name = "";
    switch (sampler_type) {
        case llama_sampler_type::DRY:         name = "dry";         break;
        case llama_sampler_type::TOP_K:       name = "top_k";       break;
        case llama_sampler_type::TFS_Z:       name = "tfs_z";       break;
        case llama_sampler_type::TYPICAL_P:   name = "typical_p";   break;
        case llama_sampler_type::TOP_P:       name = "top_p";       break;
        case llama_sampler_type::MIN_P:       name = "min_p";       break;
        case llama_sampler_type::TEMPERATURE: name = "temperature"; break;
        case llama_sampler_type::XTC:         name = "xtc";         break;
        case llama_sampler_type::TOP_N_SIGMA: name = "top_n_sigma"; break;
        case llama_sampler_type::ADAPTIVE_P:  name = "adaptive_p";  break;
        default:                                                    break;
    }
    return name;
}
std::vector<llama_sampler_type> llama_sampling_types_from_names(const std::vector<std::string> & names, bool allow_alt_names) {
std::unordered_map<std::string, llama_sampler_type> sampler_canonical_name_map {
{"dry", llama_sampler_type::DRY},
{"top_k", llama_sampler_type::TOP_K},
{"top_p", llama_sampler_type::TOP_P},
{"typical_p", llama_sampler_type::TYPICAL_P},
{"min_p", llama_sampler_type::MIN_P},
{"tfs_z", llama_sampler_type::TFS_Z},
{"xtc", llama_sampler_type::XTC},
{"top_n_sigma", llama_sampler_type::TOP_N_SIGMA},
{"temperature", llama_sampler_type::TEMPERATURE},
{"adaptive_p", llama_sampler_type::ADAPTIVE_P},
};
std::unordered_map<std::string, llama_sampler_type> sampler_alt_name_map {
{"dry", llama_sampler_type::DRY},
{"top-k", llama_sampler_type::TOP_K},
{"top-p", llama_sampler_type::TOP_P},
{"nucleus", llama_sampler_type::TOP_P},
{"typical-p", llama_sampler_type::TYPICAL_P},
{"typical", llama_sampler_type::TYPICAL_P},
{"min-p", llama_sampler_type::MIN_P},
{"tfs-z", llama_sampler_type::TFS_Z},
{"tfs", llama_sampler_type::TFS_Z},
{"xtc", llama_sampler_type::XTC},
{"top-n-sigma", llama_sampler_type::TOP_N_SIGMA},
{"temp", llama_sampler_type::TEMPERATURE},
{"adaptive-p", llama_sampler_type::ADAPTIVE_P},
};
std::vector<llama_sampler_type> sampler_types;
sampler_types.reserve(names.size());
for (const auto & name : names)
{
auto sampler_item = sampler_canonical_name_map.find(name);
if (sampler_item != sampler_canonical_name_map.end())
{
sampler_types.push_back(sampler_item->second);
}
else
{
if (allow_alt_names)
{
sampler_item = sampler_alt_name_map.find(name);
if (sampler_item != sampler_alt_name_map.end())
{
sampler_types.push_back(sampler_item->second);
}
}
}
}
return sampler_types;
}
std::vector<llama_sampler_type> llama_sampling_types_from_chars(const std::string & names_string) {
std::unordered_map<char, llama_sampler_type> sampler_name_map {
{'d', llama_sampler_type::DRY},
{'k', llama_sampler_type::TOP_K},
{'p', llama_sampler_type::TOP_P},
{'y', llama_sampler_type::TYPICAL_P},
{'m', llama_sampler_type::MIN_P},
{'f', llama_sampler_type::TFS_Z},
{'x', llama_sampler_type::XTC},
{'n', llama_sampler_type::TOP_N_SIGMA},
{'t', llama_sampler_type::TEMPERATURE},
{'w', llama_sampler_type::ADAPTIVE_P},
};
std::vector<llama_sampler_type> sampler_types;
sampler_types.reserve(names_string.size());
for (const auto & c : names_string) {
const auto sampler_item = sampler_name_map.find(c);
if (sampler_item != sampler_name_map.end()) {
sampler_types.push_back(sampler_item->second);
}
}
return sampler_types;
}
static void sampler_queue(
struct llama_context* ctx_main,
const common_params_sampling& params,
common_sampler * ctx_sampling,
llama_token_data_array& cur_p,
size_t min_keep) {
const float temp = params.temp;
const float dynatemp_range = params.dynatemp_range;
const float dynatemp_exponent = params.dynatemp_exponent;
const int32_t top_k = params.top_k;
const float top_p = params.top_p;
const float min_p = params.min_p;
const float tfs_z = params.tfs_z;
const float typical_p = params.typical_p;
const float xtc_probability = params.xtc_probability;
const float xtc_threshold = params.xtc_threshold;
const float top_n_sigma = params.top_n_sigma;
const std::vector<llama_sampler_type> & samplers_sequence = params.samplers_sequence;
bool use_adaptive_p = false; for (auto sampler_type : samplers_sequence) {
switch (sampler_type) {
case llama_sampler_type::DRY : llama_sample_dry (ctx_main, ctx_sampling->smpl, &cur_p); break;
case llama_sampler_type::TOP_K : llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep); break;
case llama_sampler_type::TFS_Z : llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep); break;
case llama_sampler_type::TYPICAL_P : llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); break;
case llama_sampler_type::TOP_P : llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep); break;
case llama_sampler_type::MIN_P : llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep); break;
case llama_sampler_type::XTC : llama_sample_xtc (ctx_main, &cur_p, xtc_probability, xtc_threshold, min_keep); break;
case llama_sampler_type::TOP_N_SIGMA: llama_sample_top_n_sigma(ctx_main, &cur_p, top_n_sigma); break;
case llama_sampler_type::DIST : llama_sample_dist (ctx_main, &cur_p); break;
case llama_sampler_type::TEMPERATURE:
if (dynatemp_range > 0) {
float dynatemp_min = std::max(0.0f, temp - dynatemp_range);
float dynatemp_max = std::max(0.0f, temp + dynatemp_range);
llama_sample_entropy(ctx_main, &cur_p, dynatemp_min, dynatemp_max, dynatemp_exponent);
} else {
llama_sample_temp(ctx_main, &cur_p, temp);
}
break;
case llama_sampler_type::ADAPTIVE_P: use_adaptive_p = ctx_sampling->adapt_p_ctx != nullptr; break;
default : break;
}
}
if (use_adaptive_p) {
llama_sample_adaptive_p(ctx_main, &cur_p, ctx_sampling->adapt_p_ctx);
}
}
// Decide whether the grammar constrains the next token.
//
// Returns false when there is no grammar. Returns true when there is no reasoning
// budget or the grammar is not lazy. For a lazy grammar with a reasoning budget, the
// grammar applies only while the budget is IDLE or DONE.
static bool grammar_should_apply(struct common_sampler * gsmpl) {
    if (gsmpl->grammar == nullptr) {
        return false;
    }
    const bool budget_gates_grammar = gsmpl->rbudget != nullptr && gsmpl->params.grammar_lazy;
    if (!budget_gates_grammar) {
        return true;
    }
    const auto state = common_reasoning_budget_get_state(gsmpl->rbudget);
    return state == REASONING_BUDGET_IDLE || state == REASONING_BUDGET_DONE;
}
// Sample one token from output `idx` of ctx_main.
//
// Pipeline:
//   1. llama_sampling_prepare: logit biases, CFG guidance, repetition penalties, and
//      the grammar when grammar_first is set;
//   2. reasoning-budget constraints;
//   3. the grammar, again, when grammar_first and grammar_should_apply();
//   4. the sampler itself: softmax-argmax (temp < 0), greedy (temp == 0), mirostat
//      v1/v2, adaptive-p, or the configured sampler chain.
//
// Without grammar_first, the grammar only validates the chosen token. If that token
// is rejected, the original logits are restored and sampling is redone with
// grammar_first = true.
//
// ctx_sampling->n_valid is updated on every successful path: 0 for greedy sampling,
// otherwise the size of the final candidate array. Previously an early return
// skipped this update whenever there was no grammar or grammar_first was set.
static llama_token llama_sampling_sample_impl(
        struct common_sampler * ctx_sampling,
        struct llama_context * ctx_main,
        struct llama_context * ctx_cfg,
        const int idx,
        bool grammar_first) {
    const common_params_sampling & params = ctx_sampling->params;
    const float temp = params.temp;
    const int mirostat = params.mirostat;
    const float mirostat_tau = params.mirostat_tau;
    const float mirostat_eta = params.mirostat_eta;
    const float adaptive_target = params.adaptive_target;

    std::vector<float> original_logits;
    llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, grammar_first, &original_logits);
    llama_token_data_array & cur_p = ctx_sampling->cur_p;
    if (ctx_sampling->grammar != NULL && !grammar_first) {
        // Needed to restore the logits if the sampled token turns out to violate the grammar.
        GGML_ASSERT(!original_logits.empty());
    }

    auto & rbudget = ctx_sampling->rbudget;
    llama_token id = 0;
    common_reasoning_budget_apply(rbudget, &cur_p);
    if (ctx_sampling->grammar != NULL && grammar_first && grammar_should_apply(ctx_sampling)) {
        llama_grammar_apply(ctx_sampling->grammar, ctx_main, &cur_p);
    }

    if (temp < 0.0) {
        // Negative temperature: pick the most likely token, but keep the softmax probabilities.
        llama_sample_softmax(ctx_main, &cur_p);
        id = cur_p.data[0].id;
    } else if (temp == 0.0) {
        id = llama_sample_token_greedy(ctx_main, &cur_p);
    } else {
        if (mirostat == 1) {
            const int mirostat_m = 100;
            llama_sample_temp(ctx_main, &cur_p, temp);
            id = llama_sample_token_mirostat(ctx_main, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_sampling->mirostat_mu);
        } else if (mirostat == 2) {
            llama_sample_temp(ctx_main, &cur_p, temp);
            id = llama_sample_token_mirostat_v2(ctx_main, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu);
        } else if (adaptive_target >= 0.0f && ctx_sampling->adapt_p_ctx != nullptr) {
            llama_prep_adaptive_p(ctx_main, &cur_p, ctx_sampling->adapt_p_ctx);
            sampler_queue(ctx_main, params, ctx_sampling, cur_p, std::max(1, params.min_keep));
            id = llama_sample_token_adaptive_p(ctx_main, &cur_p, ctx_sampling->adapt_p_ctx);
        } else {
            size_t min_keep = std::max(1, params.min_keep);
            sampler_queue(ctx_main, params, ctx_sampling, cur_p, min_keep);
            id = llama_sample_token_with_rng(ctx_main, &cur_p, ctx_sampling->rng);
        }
    }

    // Grammar-after mode: check only the chosen token against the grammar (cheap),
    // and fall back to a full grammar-first resample if it is rejected.
    if (ctx_sampling->grammar != NULL && !grammar_first && grammar_should_apply(ctx_sampling)) {
        float * logits = llama_get_logits_ith(ctx_main, idx);
        llama_token_data single_token_data = {id, logits[id], 0.0f};
        llama_token_data_array single_token_data_array = { &single_token_data, 1, false };
        llama_grammar_apply(ctx_sampling->grammar, ctx_main, &single_token_data_array);
        const bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
        if (!is_valid) {
            LOG("Resampling because token %d: '%s' does not meet grammar rules\n", id, common_token_to_piece(ctx_main, id).c_str());
            std::copy(original_logits.begin(), original_logits.end(), logits);
            return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, true);
        }
    }

    ctx_sampling->n_valid = temp == 0.0f ? 0 : cur_p.size;
    return id;
}
// Build the candidate array for output `idx` of ctx_main in ctx_sampling->cur / cur_p.
//
// Order of operations:
//   1. snapshot the raw logits into *original_logits (grammar present and !grammar_first);
//   2. add params.logit_bias in place (ids outside [0, n_vocab) are ignored);
//   3. apply CFG guidance from ctx_cfg, if given;
//   4. apply expiring logit biases, if an ELB state is active;
//   5. copy the logits into `cur` (adding ctx_sampling->server_biases when its size
//      matches the vocabulary);
//   6. apply repetition/frequency/presence penalties over the last penalty_last_n
//      history tokens. If penalize_nl is false, the newline candidate is restored to
//      its pre-penalty value;
//   7. apply the grammar when grammar_first.
// Steps 2-4 modify the model's logit buffer in place.
// Returns a copy of the resulting cur_p descriptor.
static llama_token_data_array llama_sampling_prepare_impl(
        struct common_sampler * ctx_sampling,
        struct llama_context * ctx_main,
        struct llama_context * ctx_cfg,
        const int idx,
        bool grammar_first,
        std::vector<float> * original_logits) {
    const common_params_sampling & params = ctx_sampling->params;
    const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
    const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
    const float penalty_repeat = params.penalty_repeat;
    const float penalty_freq = params.penalty_freq;
    const float penalty_present = params.penalty_present;
    const bool penalize_nl = params.penalize_nl;
    auto & prev = ctx_sampling->prev;
    auto & cur = ctx_sampling->cur;
    float * logits = llama_get_logits_ith(ctx_main, idx);

    if (ctx_sampling->grammar != NULL && !grammar_first) {
        GGML_ASSERT(original_logits != NULL);
        *original_logits = {logits, logits + n_vocab};
    }

    // User-supplied biases: guard the index so a bad token id cannot write outside the logit buffer.
    for (const auto & bias : params.logit_bias) {
        if (bias.first >= 0 && bias.first < n_vocab) {
            logits[bias.first] += bias.second;
        }
    }

    if (ctx_cfg) {
        float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx);
        llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale);
    }

    if (ctx_sampling->elb_states.size() > ctx_sampling->elb_idx) {
        common_expiring_logit_bias_apply(ctx_sampling, logits);
    }

    cur.resize(n_vocab);
    const bool use_server_biases = ctx_sampling->server_biases != nullptr
        && ctx_sampling->server_biases->size() == static_cast<size_t>(n_vocab);
    if (use_server_biases) {
        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
            cur[token_id] = llama_token_data{token_id, logits[token_id] + ctx_sampling->server_biases->at(token_id), 0.0f};
        }
    } else {
        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
            cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
        }
    }
    ctx_sampling->cur_p = { cur.data(), cur.size(), false };
    llama_token_data_array & cur_p = ctx_sampling->cur_p;

    const auto & penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
    const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);
    if (penalty_tokens_used_size > 0) {
        const llama_token nl_token = llama_token_nl(llama_get_model(ctx_main));
        // Models without a newline token report a negative id; never index with it.
        const bool restore_nl = !penalize_nl && nl_token >= 0 && nl_token < n_vocab;
        // `cur` is still in token-id order here, so cur[nl_token] is the newline candidate
        // with every bias (including server_biases) applied but no penalty yet.
        const float nl_logit = restore_nl ? cur[nl_token].logit : 0.0f;

        llama_sample_repetition_penalties(ctx_main, &cur_p,
                penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
                penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);

        if (restore_nl) {
            for (size_t i = 0; i < cur_p.size; i++) {
                if (cur_p.data[i].id == nl_token) {
                    cur_p.data[i].logit = nl_logit;
                    break;
                }
            }
        }
    }

    if (grammar_first && ctx_sampling->grammar != NULL) {
        llama_grammar_apply(ctx_sampling->grammar, ctx_main, &cur_p);
    }
    return cur_p;
}
// Sample a token with an optional classifier-free-guidance context (ctx_cfg may be
// null). The grammar, if any, validates the sampled token afterwards instead of
// filtering the candidates first.
llama_token common_sampler_sample_legacy(
        struct common_sampler * ctx_sampling,
        struct llama_context * ctx_main,
        struct llama_context * ctx_cfg,
        const int idx) {
    constexpr bool apply_grammar_first = false;
    return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, apply_grammar_first);
}
// Sample a token from output `idx` of ctx_main, without CFG guidance.
// grammar_first selects between filtering candidates with the grammar before
// sampling and validating the sampled token afterwards.
llama_token common_sampler_sample(
        struct common_sampler * ctx_sampling,
        struct llama_context * ctx_main,
        const int idx,
        bool grammar_first) {
    struct llama_context * const no_guidance = nullptr;
    return llama_sampling_sample_impl(ctx_sampling, ctx_main, no_guidance, idx, grammar_first);
}
// Public entry point that builds the candidate array for output `idx`. It delegates
// to llama_sampling_prepare_impl. original_logits receives a copy of the raw logits
// when a grammar is present and grammar_first is false.
llama_token_data_array llama_sampling_prepare(
        struct common_sampler * ctx_sampling,
        struct llama_context * ctx_main,
        struct llama_context * ctx_cfg,
        const int idx,
        bool grammar_first,
        std::vector<float> * original_logits) {
    return llama_sampling_prepare_impl(
        ctx_sampling,
        ctx_main,
        ctx_cfg,
        idx,
        grammar_first,
        original_logits);
}
// Record `token` as accepted.
//
// The token is appended to the history window, which holds at most
// max(params.n_prev, 1) tokens; the oldest entries are dropped once it is full.
// For generated tokens, the reasoning budget and, when grammar_should_apply() held
// before the budget saw the token, the grammar are advanced. The DRY state and the
// expiring-logit-bias state are always advanced.
//
// The history is trimmed only when full. The previous code erased the front whenever
// the history was non-empty, so a history emptied by common_sampler_reset (or never
// pre-sized) stayed stuck at a single token.
void common_sampler_accept(
        struct common_sampler * ctx_sampling,
        struct llama_context * ctx_main,
        llama_token token,
        bool is_generated) {
    auto & prev = ctx_sampling->prev;
    const size_t capacity = ctx_sampling->params.n_prev > 0
        ? static_cast<size_t>(ctx_sampling->params.n_prev)
        : 1;
    if (prev.size() >= capacity) {
        const size_t n_drop = prev.size() - capacity + 1;
        prev.erase(prev.begin(), prev.begin() + static_cast<std::ptrdiff_t>(n_drop));
    }
    prev.push_back(token);

    // Evaluate the grammar gate before the reasoning budget consumes this token.
    const auto accept_grammar = is_generated && grammar_should_apply(ctx_sampling);
    if (ctx_sampling->rbudget && is_generated) {
        common_reasoning_budget_accept(ctx_sampling->rbudget, token);
    }
    if (ctx_sampling->grammar && accept_grammar) {
        llama_grammar_accept_token(ctx_sampling->grammar, ctx_main, token);
    }
    if (ctx_sampling->smpl) {
        llama_sampler_dry_accept(ctx_sampling->smpl, token);
    }
    if (ctx_sampling->elb_states.size() > ctx_sampling->elb_idx) {
        common_expiring_logit_bias_accept(ctx_sampling, ctx_main);
    }
}
// Return the candidate array produced by the most recent sampling call.
//
// With do_sort set (and the array not yet sorted), candidates are sorted by
// descending probability `p`, and `selected` is remapped so it keeps pointing at the
// same token. If `selected` does not refer to a valid entry (for example, the array
// is empty), the array is still sorted but `selected` is left untouched. Previously
// data[selected] was read unconditionally.
llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl, bool do_sort) {
    auto * res = &gsmpl->cur_p;
    if (!do_sort || res->sorted) {
        return res;
    }

    const bool has_selected = res->selected >= 0 && static_cast<size_t>(res->selected) < res->size;
    const llama_token selected_id = has_selected ? res->data[res->selected].id : -1;

    std::sort(res->data, res->data + res->size, [](const llama_token_data & a, const llama_token_data & b) {
        return a.p > b.p;
    });

    if (has_selected) {
        for (size_t i = 0; i < res->size; ++i) {
            if (res->data[i].id == selected_id) {
                res->selected = i;
                break;
            }
        }
    }
    res->sorted = true;
    return res;
}
std::vector<llama_token> llama_sampling_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<llama_token> & draft) {
std::vector<int> idxs(draft.size() + 1);
for (size_t i = 0; i < idxs.size(); ++i) {
idxs[i] = i;
}
return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft);
}
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const std::vector<llama_token> & draft, bool grammar_first) {
GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");
std::vector<llama_token> result;
result.reserve(idxs.size());
size_t i = 0;
for (; i < draft.size(); i++) {
const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
gsmpl->drafted_text += common_token_to_piece(ctx, id, true);
common_sampler_accept(gsmpl, ctx, id, true);
result.push_back(id);
if (draft[i] != id) {
break;
}
}
if (i == draft.size()) {
const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
gsmpl->drafted_text += common_token_to_piece(ctx, id, true);
common_sampler_accept(gsmpl, ctx, id, true);
result.push_back(id);
}
return result;
}
// Debug-log the current value, in `sparams`, of every sampling parameter that `entry`
// adjusts, i.e. whose entry.addsubs slot has a non-zero magnitude.
// The body expands the X_COMMON_PARAMS_SAMPLING X-macro list twice. The first
// expansion builds the table of parameter names (a function-local static, built on
// first call). The second emits one guarded LLAMA_LOG_DEBUG per parameter.
// NOTE(review): names[] is indexed by SPARAMS_<MEMBER>_ENUM, which assumes those enum
// values follow the X_COMMON_PARAMS_SAMPLING order — presumably both come from the same list; confirm.
static void elb_print(common_params_sampling& sparams, const common_params_sampling::elb_param::elb_entry& entry) {
#undef X
// Expansion 1: each X(...) entry becomes its stringized member name.
#define X(T, MEMBER, DV, PRECAST) #MEMBER,
static const std::vector<std::string> names = { X_COMMON_PARAMS_SAMPLING };
#undef X
// Expansion 2: each X(...) entry becomes a log statement guarded by a non-zero addsubs slot.
// A_DOT_B(sparams, MEMBER) presumably reads sparams.MEMBER; the value is printed as float.
#define X(T, MEMBER, DV, PRECAST) if (std::abs(entry.addsubs[SPARAMS_ ## MEMBER ## _ENUM]) > 0.0f) \
{ LLAMA_LOG_DEBUG("%s: %s = %f\n", __func__, names[SPARAMS_ ## MEMBER ## _ENUM].c_str(), float(A_DOT_B(sparams, MEMBER))); }
X_COMMON_PARAMS_SAMPLING
}
// Apply the parameter adjustments stored in `entry` to `sparams`.
// For every parameter in the X_COMMON_PARAMS_SAMPLING list, the member (accessed via
// A_DOT_B — presumably `sparams.MEMBER`) is increased by
// entry.addsubs[SPARAMS_<MEMBER>_ENUM]. The delta is passed through PRECAST and cast to
// the member type T. elb_sub performs the matching subtraction.
// NOTE(review): for integer T the cast truncates the delta, and a zero slot is assumed
// to leave the member unchanged (i.e. PRECAST(0) == 0) — confirm against the X-macro definition.
static void elb_add(common_params_sampling& sparams, const common_params_sampling::elb_param::elb_entry& entry) {
#undef X
#define X(T, MEMBER, _, PRECAST) A_DOT_B(sparams, MEMBER) += static_cast<T>(PRECAST(entry.addsubs[SPARAMS_ ## MEMBER ## _ENUM]));
X_COMMON_PARAMS_SAMPLING
}
// Undo the parameter adjustments stored in `entry`: the exact mirror of elb_add.
// Each X_COMMON_PARAMS_SAMPLING member is decreased by
// static_cast<T>(PRECAST(entry.addsubs[SPARAMS_<MEMBER>_ENUM])).
// The same casted delta that elb_add added is subtracted, so an add/sub pair restores
// the original value (barring floating-point rounding).
static void elb_sub(common_params_sampling& sparams, const common_params_sampling::elb_param::elb_entry& entry) {
#undef X
#define X(T, MEMBER, _, PRECAST) A_DOT_B(sparams, MEMBER) -= static_cast<T>(PRECAST(entry.addsubs[SPARAMS_ ## MEMBER ## _ENUM]));
X_COMMON_PARAMS_SAMPLING
}
void common_expiring_logit_bias_apply(struct common_sampler* ctx_sampling, float* logits) {
auto index_first_inactive = [](auto countup, auto& tokens) {
return std::distance(
tokens.begin(),
std::upper_bound(tokens.begin(), tokens.end(), countup, [](const auto& countup, const auto& token) {
return countup > token.duration;
})
);
};
const auto& elb = ctx_sampling->elb_states[ctx_sampling->elb_idx];
std::string combined_text;
const std::string* search_window = &combined_text;
if (!ctx_sampling->drafted_text.empty()) {
combined_text = ctx_sampling->to_generated_text != nullptr ? (
ctx_sampling->to_generated_text->substr(std::max(0, int32_t(ctx_sampling->to_generated_text->length()) - elb.max_cond_len))
) : "" + ctx_sampling->drafted_text;
} else if (ctx_sampling->to_generated_text != nullptr) {
search_window = ctx_sampling->to_generated_text;
}
if (!search_window->empty() && !elb.other_tokens.empty() && (elb.other_tokens.front().duration > elb.countup)) {
const auto ifi = index_first_inactive(elb.countup, elb.other_tokens);
for (size_t j = 0; j < ifi; ++j) {
const auto& [id, bias, _, cond] = elb.other_tokens[j];
if (string_ends_with(*search_window, cond)) {
logits[id] += bias;
}
}
}
if (!elb.first_tokens.empty() && (elb.first_tokens.front().duration > elb.countup)) {
const auto ifi = index_first_inactive(elb.countup, elb.first_tokens);
if (search_window->empty()) {
for (size_t j = 0; j < ifi; ++j) {
logits[elb.first_tokens[j].id] += elb.first_tokens[j].bias;
}
} else {
for (size_t j = 0; j < ifi; ++j) {
const auto& [id, bias, _, cond] = elb.first_tokens[j];
if (!string_ends_with(*search_window, cond)) {
logits[id] += bias;
}
}
}
}
for (auto& entry: ctx_sampling->params.elb_params[ctx_sampling->elb_idx].entries) {
if (!entry.biases.empty()) {
continue; }
for (size_t j = 0; j < entry.phrases.size(); ++j) {
const auto& phrase = entry.phrases[j];
if (phrase.empty()) {
if (elb.countup == 0) {
LLAMA_LOG_DEBUG("%s: before add\n", __func__);
elb_print(ctx_sampling->params, entry);
elb_add(ctx_sampling->params, entry);
entry.addflags[j] = true;
LLAMA_LOG_DEBUG("%s: after add\n", __func__);
elb_print(ctx_sampling->params, entry);
} else if (elb.countup == entry.duration) {
LLAMA_LOG_DEBUG("%s: before sub\n", __func__);
elb_print(ctx_sampling->params, entry);
elb_sub(ctx_sampling->params, entry);
entry.addflags[j] = false;
LLAMA_LOG_DEBUG("%s: after sub\n", __func__);
elb_print(ctx_sampling->params, entry);
}
continue; }
size_t count = 0;
auto pos = ctx_sampling->to_generated_text->find(phrase, entry.posi[j]);
while (pos != std::string::npos) {
LLAMA_LOG_DEBUG("%s: found %s @ %zu\n", __func__, phrase.c_str(), pos);
++count;
pos = ctx_sampling->to_generated_text->find(phrase, pos + phrase.length());
}
entry.posi[j] = std::max(0, int32_t(ctx_sampling->to_generated_text->length()) - int32_t(phrase.length()) + 1);
if (count % 2 == 1) {
LLAMA_LOG_DEBUG("%s: before\n", __func__);
elb_print(ctx_sampling->params, entry);
(entry.addflags[j] ? elb_sub : elb_add)(ctx_sampling->params, entry);
entry.addflags[j] = !entry.addflags[j];
LLAMA_LOG_DEBUG("%s: after\n", __func__);
elb_print(ctx_sampling->params, entry);
}
}
}
}
// Advance the expiring-logit-bias (ELB) state machine after a token was accepted.
//
// The current state's countup is incremented. Once it reaches the state's delay, the
// generated text since elb_search_pos, plus the piece of the newest token, is searched
// for the jump word (go to jump_idx) and then for the exit word (go to the next state).
// If neither is found, elb_search_pos advances, keeping the last search_word_len - 1
// characters so that a word split across tokens is still found.
//
// On a transition, every parameter adjustment the left state still has applied is
// undone, and its flag is cleared so that re-entering the state later cannot subtract
// it a second time. The new state's phrase scan positions then restart after the
// matched word. Leaving the last state is allowed; there is simply no next state to
// initialise (previously elb_params was indexed out of range in that case).
void common_expiring_logit_bias_accept(struct common_sampler* ctx_sampling, struct llama_context * ctx_main) {
    if (ctx_sampling->to_generated_text == nullptr) {
        return;
    }
    const auto idx = ctx_sampling->elb_idx;
    auto& elb = ctx_sampling->elb_states[idx];
    if ((elb.delay > ++elb.countup) || (elb.search_word_len == 0)) {
        return;
    }
    const std::string window = ctx_sampling->to_generated_text->substr(std::min(
        ctx_sampling->to_generated_text->length(),
        ctx_sampling->elb_search_pos)) + common_token_to_piece(ctx_main, ctx_sampling->prev.back(), true);
    size_t pos = 0;
    if (string_is_found(window, elb.jumpword, pos)) {
        LLAMA_LOG_DEBUG("%s: found %s in %s @ %zu\n", __func__, string_unescape(elb.jumpword).c_str(), string_unescape(window).c_str(), pos);
        pos += ctx_sampling->elb_search_pos + elb.jumpword.length();
        ctx_sampling->elb_idx = elb.jump_idx;
    } else if (string_is_found(window, elb.exitword, pos)) {
        LLAMA_LOG_DEBUG("%s: found %s in %s @ %zu\n", __func__, string_unescape(elb.exitword).c_str(), string_unescape(window).c_str(), pos);
        pos += ctx_sampling->elb_search_pos + elb.exitword.length();
        ++ctx_sampling->elb_idx;
    } else {
        ctx_sampling->elb_search_pos += std::max(0, int32_t(window.length()) - int32_t(elb.search_word_len) + 1);
        return;
    }
    ctx_sampling->elb_search_pos = pos + 1;

    // Undo what the state being left still has applied, and record that it is undone.
    for (auto& entry: ctx_sampling->params.elb_params[idx].entries) {
        for (size_t j = 0; j < entry.addflags.size(); ++j) {
            if (entry.addflags[j]) {
                LLAMA_LOG_DEBUG("%s: before\n", __func__);
                elb_print(ctx_sampling->params, entry);
                elb_sub(ctx_sampling->params, entry);
                entry.addflags[j] = false;
                LLAMA_LOG_DEBUG("%s: after\n", __func__);
                elb_print(ctx_sampling->params, entry);
            }
        }
    }

    if (ctx_sampling->elb_idx >= ctx_sampling->params.elb_params.size()) {
        // Past the last configured state: nothing to initialise.
        return;
    }
    for (auto& entry: ctx_sampling->params.elb_params[ctx_sampling->elb_idx].entries) {
        std::fill(entry.posi.begin(), entry.posi.end(), pos);
    }
}
// Serialize a grammar trigger as {"type": <int>, "value": <string>}. Token triggers
// additionally carry "token": <int>.
template <>
json common_grammar_trigger::to_json() const {
    json serialized = json::object();
    serialized["type"]  = static_cast<int>(type);
    serialized["value"] = value;
    if (type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) {
        serialized["token"] = static_cast<int>(token);
    }
    return serialized;
}
// Parse a grammar trigger produced by to_json(). "type" and "value" are required,
// and "token" is required for token triggers. nlohmann::json::at throws on missing keys.
template <>
common_grammar_trigger common_grammar_trigger::from_json(const json& in) {
    common_grammar_trigger trigger;
    const int raw_type = in.at("type").get<int>();
    trigger.type  = static_cast<common_grammar_trigger_type>(raw_type);
    trigger.value = in.at("value").get<std::string>();
    if (trigger.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) {
        const int raw_token = in.at("token").get<int>();
        trigger.token = static_cast<llama_token>(raw_token);
    }
    return trigger;
}
#if defined(__GNUC__) && (defined(__x86_64__) || defined(__i386__))
// AVX2 arg-max over logits[0, n_vocab).
//
// Returns false, leaving best_id/max_val untouched, when n_vocab < 8; the caller then
// falls back to its scalar scan. Otherwise stores the index and value of the largest
// logit and returns true.
//
// Ties resolve to the lowest index, matching a strict '>' scalar scan. Each lane keeps
// the first index at which it saw its maximum. When lanes are merged, equal maxima
// prefer the smaller index; previously the lane order decided, which could pick a
// higher index than the scalar path.
__attribute__((target("avx2")))
static bool common_sampler_speculative_top1_avx2(const float * logits, const int n_vocab, int & best_id, float & max_val) {
    if (n_vocab < 8) {
        return false;
    }
    // Eight independent running maxima (lane l sees indices l, l+8, l+16, ...).
    __m256 max_v = _mm256_loadu_ps(logits);
    __m256i id_v = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
    const __m256i step = _mm256_set1_epi32(8);
    __m256i cur_id = _mm256_add_epi32(id_v, step);
    int i = 8;
    for (; i + 7 < n_vocab; i += 8) {
        const __m256 x = _mm256_loadu_ps(logits + i);
        const __m256 gt_max = _mm256_cmp_ps(x, max_v, _CMP_GT_OQ);
        max_v = _mm256_blendv_ps(max_v, x, gt_max);
        // The compare mask is all-ones per 32-bit lane, so a byte-wise blend selects whole ids.
        id_v = _mm256_blendv_epi8(id_v, cur_id, _mm256_castps_si256(gt_max));
        cur_id = _mm256_add_epi32(cur_id, step);
    }
    alignas(32) float max_buf[8];
    alignas(32) int id_buf[8];
    _mm256_store_ps(max_buf, max_v);
    _mm256_store_si256((__m256i *) id_buf, id_v);
    best_id = id_buf[0];
    max_val = max_buf[0];
    for (int j = 1; j < 8; ++j) {
        if (max_buf[j] > max_val || (max_buf[j] == max_val && id_buf[j] < best_id)) {
            max_val = max_buf[j];
            best_id = id_buf[j];
        }
    }
    // Scalar tail: these indices are all larger than any lane index, so strict '>' keeps ties on the lower index.
    for (; i < n_vocab; ++i) {
        if (logits[i] > max_val) {
            max_val = logits[i];
            best_id = i;
        }
    }
    return true;
}
// Vectorised approximation of expf() for eight floats (AVX2 + FMA).
// Presumably the same algorithm as ggml's ggml_v_expf — NOTE(review): confirm the error bound there.
// Outline, following the constants below:
//   n = round(x * log2(e))   (0x1.715476p+0 ~= log2(e); adding/subtracting r = 1.5*2^23 rounds to an integer)
//   b = x - n*ln2            (ln2 split into a high part 0x1.62e4p-1 and a low part 0x1.7f7d1cp-20)
//   k = 2^n                  (integer n shifted into the exponent field, plus the bits of 1.0f)
//   j ~= exp(b) - 1          (polynomial in b)
//   result = k + k*j
// Lanes with |n| > 126 cannot represent 2^n directly and take a slower, two-factor scaling path.
__attribute__((target("avx2,fma")))
static inline __m256 v_expf(__m256 x) {
const __m256 r = _mm256_set1_ps(0x1.8p23f);
const __m256 z = _mm256_fmadd_ps(x, _mm256_set1_ps(0x1.715476p+0f), r);
// n = z - r: x*log2(e) rounded to the nearest integer (held as a float).
const __m256 n = _mm256_sub_ps(z, r);
// b = x - n*ln2, the reduced argument.
const __m256 b = _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.7f7d1cp-20f),
_mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.62e4p-1f), x));
// e: the integer in z's low mantissa bits moved into the exponent field; k = 2^n (valid for |n| <= 126).
const __m256i e = _mm256_slli_epi32(_mm256_castps_si256(z), 23);
const __m256 k = _mm256_castsi256_ps(
_mm256_add_epi32(e, _mm256_castps_si256(_mm256_set1_ps(1))));
// c: per-lane mask of |n| > 126 (-0.f andnot clears the sign bit, i.e. fabs).
const __m256i c = _mm256_castps_si256(
_mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
_mm256_set1_ps(126), _CMP_GT_OQ));
const __m256 u = _mm256_mul_ps(b, b);
// j: polynomial approximation of exp(b) - 1.
const __m256 j = _mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_set1_ps(0x1.0e4020p-7f), b,
_mm256_set1_ps(0x1.573e2ep-5f)), u,
_mm256_fmadd_ps(_mm256_set1_ps(0x1.555e66p-3f), b,
_mm256_set1_ps(0x1.fffdb6p-2f))),
u, _mm256_mul_ps(_mm256_set1_ps(0x1.ffffecp-1f), b));
// Fast path: no lane needs the large-|n| handling.
if (!_mm256_movemask_ps(_mm256_castsi256_ps(c)))
return _mm256_fmadd_ps(j, k, k);
// Slow path: split 2^n into s1 * s2. s1 is 2^127 for n > 0, or 2^-125 for n <= 0
// (0x82000000 + 0x7f000000 wraps to 0x01000000).
const __m256i g = _mm256_and_si256(
_mm256_castps_si256(_mm256_cmp_ps(n, _mm256_setzero_ps(), _CMP_LE_OQ)),
_mm256_set1_epi32(0x82000000u));
const __m256 s1 =
_mm256_castsi256_ps(_mm256_add_epi32(g, _mm256_set1_epi32(0x7f000000u)));
const __m256 s2 = _mm256_castsi256_ps(_mm256_sub_epi32(e, g));
// d: lanes with |n| > 192, whose result is s1*s1 (2^254 overflows to +inf, 2^-250 underflows to 0).
const __m256i d = _mm256_castps_si256(
_mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
_mm256_set1_ps(192), _CMP_GT_OQ));
// Per-lane select: d -> s1*s1; else c -> (s2 + s2*j) * s1; else the fast-path value k + k*j.
return _mm256_or_ps(
_mm256_and_ps(_mm256_castsi256_ps(d), _mm256_mul_ps(s1, s1)),
_mm256_andnot_ps(
_mm256_castsi256_ps(d),
_mm256_or_ps(
_mm256_and_ps(_mm256_castsi256_ps(c),
_mm256_mul_ps(_mm256_fmadd_ps(s2, j, s2), s1)),
_mm256_andnot_ps(_mm256_castsi256_ps(c), _mm256_fmadd_ps(k, j, k)))));
}
// Horizontal sum of the four lanes of an SSE register, computed as (x0 + x2) + (x1 + x3).
__attribute__((target("avx2")))
static inline float hsum_float_4(__m128 x) {
    const __m128 upper    = _mm_movehl_ps(x, x);                                   // [x2, x3, x2, x3]
    const __m128 pairwise = _mm_add_ps(x, upper);                                  // lane0 = x0+x2, lane1 = x1+x3
    const __m128 lane1    = _mm_shuffle_ps(pairwise, pairwise, _MM_SHUFFLE(1, 1, 1, 1));
    return _mm_cvtss_f32(_mm_add_ss(pairwise, lane1));
}
// Horizontal sum of the eight lanes of an AVX register: fold the upper 128-bit half
// onto the lower half, then reduce the remaining four lanes.
__attribute__((target("avx2")))
static inline float hsum_float_8(__m256 x) {
    const __m128 low_half  = _mm256_castps256_ps128(x);
    const __m128 high_half = _mm256_extractf128_ps(x, 1);
    return hsum_float_4(_mm_add_ps(low_half, high_half));
}
// Return 1 / sum_i exp(logits[i] - max_val) over n logits. This is the softmax
// probability of a token whose logit equals max_val.
// Full blocks of eight use v_expf with float accumulation; the remaining elements use scalar expf.
__attribute__((target("avx2,fma")))
static float prob_avx2(int n, const float * logits, float max_val) {
    int i = 0;
    float total = 0;
    if (n >= 8) {
        const __m256 shift = _mm256_set1_ps(max_val);
        __m256 acc = _mm256_setzero_ps();
        while (i < n - 7) {
            const __m256 shifted = _mm256_sub_ps(_mm256_loadu_ps(logits + i), shift);
            acc = _mm256_add_ps(acc, v_expf(shifted));
            i += 8;
        }
        total = hsum_float_8(acc);
    }
    while (i < n) {
        total += expf(logits[i] - max_val);
        ++i;
    }
    return 1.0f/total;
}
#endif
// Scalar counterpart of the AVX2 path: return 1 / sum_i exp(logits[i] - max_val),
// accumulated in double precision.
static float prob_scalar(int n, const float * logits, float max_val) {
    double denominator = 0.0;
    int k = 0;
    while (k < n) {
        const double shifted = (double)(logits[k] - max_val);
        denominator += exp(shifted);
        ++k;
    }
    return (float)(1./denominator);
}
// Greedy pick for speculative drafting: return the arg-max token at output `idx` of
// `ctx`. When out_prob is non-null, it receives that token's softmax probability.
// On GCC/Clang x86 builds, AVX2 (and, for the probability, FMA) code paths are used
// when the running CPU supports them; otherwise a scalar scan is used. gsmpl is unused.
llama_token common_sampler_sample_speculative(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, float * out_prob) {
    GGML_UNUSED(gsmpl);
    const float * logits = llama_get_logits_ith(ctx, idx);
    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
    int best_id = 0;
    float max_val = logits[0];

#if defined(__GNUC__) && (defined(__x86_64__) || defined(__i386__))
    static const bool has_avx2 = __builtin_cpu_supports("avx2");
    const bool vector_top1 = has_avx2 && common_sampler_speculative_top1_avx2(logits, n_vocab, best_id, max_val);
    if (vector_top1) {
        if (out_prob != nullptr) {
            static const bool has_fma = __builtin_cpu_supports("fma");
            *out_prob = has_fma
                ? prob_avx2(n_vocab, logits, max_val)
                : prob_scalar(n_vocab, logits, max_val);
        }
        return best_id;
    }
#endif

    // Scalar fallback; strict '>' keeps the lowest index on ties.
    for (int tok = 1; tok < n_vocab; ++tok) {
        if (logits[tok] > max_val) {
            max_val = logits[tok];
            best_id = tok;
        }
    }
    if (out_prob != nullptr) {
        *out_prob = prob_scalar(n_vocab, logits, max_val);
    }
    return best_id;
}