#pragma once
#include "llama.h"
#include "grammar-parser.h"
#include <string>
#include <vector>
#include <unordered_map>
enum class llama_sampler_type : char {
TOP_K = 'k',
TOP_P = 'p',
MIN_P = 'm',
TFS_Z = 'f',
TYPICAL_P = 'y',
TEMPERATURE = 't'
};
typedef struct llama_sampling_params {
int32_t n_prev = 64; int32_t n_probs = 0; int32_t min_keep = 0; int32_t top_k = 40; float top_p = 0.95f; float min_p = 0.05f; float tfs_z = 1.00f; float typical_p = 1.00f; float temp = 0.80f; float dynatemp_range = 0.00f; float dynatemp_exponent = 1.00f; int32_t penalty_last_n = 64; float penalty_repeat = 1.00f; float penalty_freq = 0.00f; float penalty_present = 0.00f; int32_t mirostat = 0; float mirostat_tau = 5.00f; float mirostat_eta = 0.10f; bool penalize_nl = false;
std::vector<llama_sampler_type> samplers_sequence = {
llama_sampler_type::TOP_K,
llama_sampler_type::TFS_Z,
llama_sampler_type::TYPICAL_P,
llama_sampler_type::TOP_P,
llama_sampler_type::MIN_P,
llama_sampler_type::TEMPERATURE
};
std::string grammar;
std::string cfg_negative_prompt; float cfg_scale = 1.f;
std::unordered_map<llama_token, float> logit_bias;
std::vector<llama_token> penalty_prompt_tokens;
bool use_penalty_prompt_tokens = false;
} llama_sampling_params;
struct llama_sampling_context {
llama_sampling_params params;
float mirostat_mu;
llama_grammar * grammar;
grammar_parser::parse_state parsed_grammar;
std::vector<llama_token> prev;
std::vector<llama_token_data> cur;
};
#include "common.h"
struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params);
void llama_sampling_free(struct llama_sampling_context * ctx);
void llama_sampling_reset(llama_sampling_context * ctx);
void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst);
llama_token llama_sampling_last(llama_sampling_context * ctx);
std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama_context * ctx_main, int n);
std::string llama_sampling_print(const llama_sampling_params & params);
std::string llama_sampling_order_print(const llama_sampling_params & params);
llama_token llama_sampling_sample(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
struct llama_context * ctx_cfg,
int idx = 0);
llama_token_data_array llama_sampling_prepare(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
struct llama_context * ctx_cfg,
int idx = 0,
bool apply_grammar = true,
std::vector<float> * original_logits = nullptr);
void llama_sampling_accept(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
llama_token id,
bool apply_grammar);