// slm_ikllama_sys 0.1.1
//
// ik_llama.cpp rust sys bindings
#pragma once

#include "llama.h"
#include "llama-grammar.h"
#include "reasoning-budget.h"
#include <set>
#include <random>
#include <string>
#include <unordered_map>
#include <vector>

// Expands to the token sequence `a.b`. Neither argument is parenthesized, so
// pass simple operands (an identifier or member chain); e.g. A_DOT_B(*p, m)
// expands to `*p.m`, not `(*p).m`. Not referenced elsewhere in this header.
#define A_DOT_B(a, b) a.b

// Sampler kinds. The underlying type is `char`, and each enumerator's value is
// a distinct one-character code (presumably the codes accepted by
// llama_sampling_types_from_chars — confirm in the implementation).
enum class llama_sampler_type : char {
    DRY         = 'd',
    TOP_K       = 'k',
    TOP_P       = 'p',
    MIN_P       = 'm',
    TFS_Z       = 'f',
    XTC         = 'x',
    TOP_N_SIGMA = 'n',
    TYPICAL_P   = 'y',
    TEMPERATURE = 't',
    ADAPTIVE_P  = 'w',
    DIST        = 's',
};

// Kind of a grammar trigger (stored in common_grammar_trigger::type).
// Values are pinned explicitly to the same 0..3 numbering the compiler would
// assign implicitly in declaration order.
enum common_grammar_trigger_type {
    COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN        = 0,
    COMMON_GRAMMAR_TRIGGER_TYPE_WORD         = 1,
    COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN      = 2,
    COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL = 3,
};

// One trigger for a lazy grammar: a kind tag, a text value and a token id.
// `type` has no default initializer; `token` defaults to LLAMA_TOKEN_NULL.
struct common_grammar_trigger {
    common_grammar_trigger_type type;
    std::string                 value;
    llama_token                 token{LLAMA_TOKEN_NULL};

    // JSON (de)serialization hooks; the only intended T is nlohmann::ordered_json.
    template <typename T> T to_json() const;
    template <typename T> static common_grammar_trigger from_json(const T & in);
};


// Origin of a grammar constraint. Values are pinned to the implicit 0..3
// numbering (declaration order).
enum common_grammar_type {
    COMMON_GRAMMAR_TYPE_NONE          = 0, // no grammar set
    COMMON_GRAMMAR_TYPE_USER          = 1, // user-provided GBNF (--grammar / "grammar" API field)
    COMMON_GRAMMAR_TYPE_OUTPUT_FORMAT = 2, // auto-generated from JSON schema (--json-schema / "json_schema" API field)
    COMMON_GRAMMAR_TYPE_TOOL_CALLS    = 3, // auto-generated by chat template parser for function calling
};

// A grammar string tagged with its origin (see common_grammar_type).
struct common_grammar {
    common_grammar_type type{COMMON_GRAMMAR_TYPE_NONE};
    std::string         grammar;

    // "No grammar": type NONE and empty text, so empty() is true.
    common_grammar() = default;

    // Build from an explicit origin and grammar text.
    // The GGML_ASSERT fails only for the pair (NONE, ""): use the default
    // constructor to express "no grammar". Other combinations are accepted,
    // including (NONE, non-empty) and (non-NONE, "").
    common_grammar(common_grammar_type t, std::string g)
        : type(t)
        , grammar(std::move(g)) {
        GGML_ASSERT(!(type == COMMON_GRAMMAR_TYPE_NONE && grammar.empty()));
    }

    // True unless the grammar both has a non-NONE type and non-empty text.
    bool empty() const {
        const bool has_origin = type != COMMON_GRAMMAR_TYPE_NONE;
        const bool has_text   = !grammar.empty();
        return !(has_origin && has_text);
    }
};

// Returns a reference to the stored grammar text, exactly as held in g.grammar.
// The type tag is NOT consulted: the result is "" only when the text itself is
// empty (e.g. a default-constructed common_grammar). Use common_grammar::empty()
// to decide whether a grammar is actually active.
inline const std::string & common_grammar_value(const common_grammar & g) {
    const std::string & text = g.grammar;
    return text;
}

// Whether generation_prompt should be prefilled into the grammar sampler.
// Only the auto-generated grammars (output-format and tool-calls) qualify;
// NONE and user-supplied grammars return false. Decided from g.type alone.
inline bool common_grammar_needs_prefill(const common_grammar & g) {
    switch (g.type) {
        case COMMON_GRAMMAR_TYPE_OUTPUT_FORMAT:
        case COMMON_GRAMMAR_TYPE_TOOL_CALLS:
            return true;
        default:
            return false;
    }
}


// X-macro table of the scalar sampling parameters. Each entry is
//   X(T, MEMBER, DV, PRECAST)
//     T       - C++ type of the field
//     MEMBER  - field name (becomes a member of common_params_sampling and an
//               SPARAMS_<MEMBER>_ENUM enumerator below)
//     DV      - default value
//     PRECAST - std::round for the integral/bool entries, empty for floats;
//               presumably applied to an incoming value before conversion to T
//               at a use site outside this header — confirm there.
// Layout trick: the description of each entry sits inside a block comment that
// is opened at the end of one line and closed by the leading "*/" of the next.
// The trailing backslashes are spliced before comments are removed, so the
// whole table remains one macro definition.
#define X_COMMON_PARAMS_SAMPLING                                 /*  \
    */  X( int32_t , min_keep            , 0     , std::round )  /*  0 = disabled, otherwise samplers should return at least min_keep tokens \
    */  X( int32_t , top_k               , 40    , std::round )  /*  <= 0 to use vocab size \
    */  X( float   , top_p               , 0.95f ,            )  /*  1.0 = disabled \
    */  X( float   , min_p               , 0.05f ,            )  /*  0.0 = disabled \
    */  X( float   , tfs_z               , 1.00f ,            )  /*  1.0 = disabled \
    */  X( float   , typical_p           , 1.00f ,            )  /*  1.0 = disabled \
    */  X( float   , temp                , 0.80f ,            )  /*  <= 0.0 to sample greedily, 0.0 to not output probabilities \
    */  X( float   , dynatemp_range      , 0.00f ,            )  /*  0.0 = disabled \
    */  X( float   , dynatemp_exponent   , 1.00f ,            )  /*  controls how entropy maps to temperature in dynamic temperature sampler \
    */  X( int32_t , penalty_last_n      , 64    , std::round )  /*  last n tokens to penalize (0 = disable penalty, -1 = context size) \
    */  X( float   , penalty_repeat      , 1.00f ,            )  /*  1.0 = disabled \
    */  X( float   , penalty_freq        , 0.00f ,            )  /*  0.0 = disabled \
    */  X( float   , penalty_present     , 0.00f ,            )  /*  0.0 = disabled \
    */  X( float   , dry_multiplier      , 0.0f  ,            )  /*  0.0 = disabled; DRY repetition penalty for tokens extending repetition: \
    */  X( float   , dry_base            , 1.75f ,            )  /*  0.0 = disabled; multiplier * base ^ (length of sequence before token - allowed length) \
    */  X( int32_t , dry_allowed_length  , 2     , std::round )  /*  tokens extending repetitions beyond this receive penalty \
    */  X( int32_t , dry_penalty_last_n  , -1    , std::round )  /*  how many tokens to scan for repetitions (0 = disable penalty, -1 = context size) \
    */  X( int32_t , mirostat            , 0     , std::round )  /*  0 = disabled, 1 = mirostat, 2 = mirostat 2.0 \
    */  X( float   , mirostat_tau        , 5.00f ,            )  /*  target entropy \
    */  X( float   , mirostat_eta        , 0.10f ,            )  /*  learning rate \
    */  X( float   , xtc_probability     , 0.0f  ,            )  /*  xtc probability \
    */  X( float   , xtc_threshold       , 1.0f  ,            )  /*  xtc threshold, disabled if > 0.5 \
    */  X( float   , top_n_sigma         , 0.0f  ,            )  /*  top-n-sigma \
    */  X( float   , adaptive_target     , -1.0f ,            )  /*  select tokens near this probability (valid range 0.0 to 1.0; <0 = disabled) \
    */  X( float   , adaptive_decay      , 0.90f ,            )  /*  decay rate for target adaptation over time. lower values -> faster but less stable adaptation. (valid range 0.0 to 1.0; <= 0 = no adaptation) \
    */  X( bool    , adaptive_updt_w_cur , false , std::round )  /*  update state with current probability \
    */

// One enumerator per X_COMMON_PARAMS_SAMPLING entry, named
// SPARAMS_<member>_ENUM and numbered from 0 in table order.
// X is left defined as this expansion until the next #undef X.
#undef X
#define X(T, MEMBER, DV, PRECAST) SPARAMS_ ## MEMBER ## _ENUM,
enum {
    X_COMMON_PARAMS_SAMPLING
};

// Sampling parameters.
// The scalar knobs are generated from X_COMMON_PARAMS_SAMPLING (one field per
// entry, initialized to that entry's default); the remaining fields follow.
// `llama_sampling_params` is kept as an alias of this type.
typedef struct common_params_sampling {
    #undef X
    #define X(T, MEMBER, DV, _) T MEMBER = DV;
                X_COMMON_PARAMS_SAMPLING
    int32_t     n_prev                = 64;                 // number of previous tokens to remember
    int32_t     n_probs               = 0;                  // if greater than 0, output the probabilities of top n_probs tokens.
    int32_t     total_context_size    = 16840;
    bool        penalize_nl           = false;              // consider newlines as a repeatable token
    uint32_t    seed                  = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context

    std::vector<std::string> dry_sequence_breakers = { "\n", ":", "\"", "*" };     // default sequence breakers for DRY

    // default sampler order
    std::vector<llama_sampler_type> samplers_sequence = {
        llama_sampler_type::DRY,
        llama_sampler_type::TOP_K,
        llama_sampler_type::TFS_Z,
        llama_sampler_type::TYPICAL_P,
        llama_sampler_type::TOP_P,
        llama_sampler_type::MIN_P,
        llama_sampler_type::XTC,
        llama_sampler_type::TOP_N_SIGMA,
        llama_sampler_type::TEMPERATURE,
        llama_sampler_type::ADAPTIVE_P,
        llama_sampler_type::DIST,
    };


    common_grammar              grammar;      // optional grammar constraint (user / output-format / tool-calls)
    bool                                grammar_lazy = false;
    std::vector<common_grammar_trigger> grammar_triggers; // optional triggers (for lazy grammars)
    std::set<llama_token>               preserved_tokens;
    // Classifier-Free Guidance
    // https://arxiv.org/abs/2306.17806
    std::string cfg_negative_prompt; // string to help guidance
    float       cfg_scale     = 1.f; // how strong is guidance

    std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens

    // The assistant generation prompt already prefilled into the prompt.
    // Fed to the grammar sampler (to advance past pre-existing tokens) and used
    // to determine the reasoning budget sampler's initial state.
    // Only applied when the grammar is of output-format or tool-calls type.
    std::string generation_prompt;

    // reasoning budget sampler parameters
    // these are populated by the server/CLI based on chat template params
    int32_t                  reasoning_budget_tokens = -1;   // -1 = disabled, >= 0 = token budget
    std::vector<llama_token> reasoning_budget_start;           // start tag token sequence
    std::vector<llama_token> reasoning_budget_end;             // end tag token sequence
    std::vector<llama_token> reasoning_budget_forced;          // forced sequence (message + end tag)


    std::vector<llama_token> penalty_prompt_tokens;
    bool                     use_penalty_prompt_tokens = false;

    // expiring logit bias
    struct elb_param {
        struct elb_entry {
            std::vector<size_t>         posi;                   // positions of phrases in generated text
            std::vector<float>          addsubs;                // add/modify then subtract/restore sampling parameters
            std::vector<bool>           addflags;               // true if added
            // Scalars carry default initializers so a default-initialized entry
            // (e.g. `elb_entry e;`) never holds indeterminate values that
            // operator== or later code could read.
            size_t                      max_phrase_len = 0;
            std::vector<std::string>    phrases;
            std::vector<float>          biases;                 // for each phrase, nth bias for nth token, extrapolate
            int32_t                     duration       = 0;     // bias duration, unless exitword matches
            bool                        is_range       = false; // has lower and upper biases

            // Field-wise equality.
            // NOTE(review): max_phrase_len is not compared — presumably it is
            // derived from `phrases`; confirm before relying on it.
            bool operator == (const struct elb_entry& other) const {
                return (is_range == other.is_range)
                    && (duration == other.duration)
                    && (biases == other.biases)
                    && (phrases == other.phrases)
                    && (addflags == other.addflags)
                    && (addsubs == other.addsubs)
                    && (posi == other.posi);
            }
        };
        std::vector<struct elb_entry> entries;
        std::string                   exitword;     // move to next state if matched during generation
        std::string                   op;           // exitword operator
        bool operator == (const struct elb_param& other) const {
            return (op == other.op)
                && (exitword == other.exitword)
                && (entries == other.entries);
        }
    };
    std::vector<struct elb_param> elb_params;

} llama_sampling_params;

// general sampler context
// TODO: move to llama.h
//
// Every raw pointer defaults to nullptr and every scalar to zero, so a
// default-initialized instance (e.g. `new common_sampler`) holds no
// indeterminate values before the initializing code fills it in.
// Ownership of the pointees is not expressed here — presumably handled by
// common_sampler_init / common_sampler_free; confirm there.
struct common_sampler {
    // parameters that will be used for sampling
    common_params_sampling params;

    // mirostat sampler state
    float mirostat_mu = 0.0f;

    std::string grammar_str;
    std::string grammar_root;

    llama_grammar * grammar = nullptr;

    // TODO: replace with ring-buffer
    std::vector<llama_token>      prev;
    std::vector<llama_token_data> cur;
    llama_sampler_dry* smpl = nullptr;

    llama_sampler_adaptive_p * adapt_p_ctx = nullptr;    // adaptive p sampler

    common_reasoning_budget_ctx * rbudget = nullptr; // reasoning budget sampler

    size_t n_valid = 0; // Number of correct top tokens with correct probabilities.

    llama_token_data_array cur_p; // current candidates

    std::mt19937 rng;

    std::vector<float>* server_biases = nullptr;

    std::string  drafted_text;
    std::string* to_generated_text = nullptr;

    // expiring logit bias
    struct elb_state {
        struct elb_token {
            int32_t     id       = 0;
            float       bias     = 0.0f;
            size_t      duration = 0;
            std::string cond;       // bias activation condition
        };
        std::vector<struct elb_token> first_tokens;     // first token of each phrase
        std::vector<struct elb_token> other_tokens;
        std::string                   exitword;
        size_t                        countup         = 0;  // compare against duration
        size_t                        delay           = 0;  // to avoid early termination of positively biased phrases
        int32_t                       max_cond_len    = 0;
        std::string                   jumpword;
        size_t                        jump_idx        = 0;
        size_t                        search_word_len = 0;
    };
    std::vector<struct elb_state> elb_states;
    size_t                        elb_idx        = 0;   // for elb_states
    size_t                        elb_search_pos = 0;
};



// Create a new sampling context instance for `model`, configured from `params`.
struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params);

// Free a sampling context (presumably one returned by common_sampler_init).
void common_sampler_free(struct common_sampler * ctx);

// Reset the sampler context
// - clear prev tokens
// - reset grammar
void common_sampler_reset(common_sampler * ctx);

// Review stateful samplers
// - rewind internal states (maybe)
// NOTE(review): the exact meaning of n_unsent / rewind_status is defined by the
// implementation, which is not part of this header — document once confirmed.
void common_sampler_review(common_sampler * ctx, const size_t n_unsent, const bool rewind_status);

// Set the sampler seed
void llama_sampling_set_rng_seed(struct common_sampler * ctx, uint32_t seed);

// Copy the sampler context from `src` into `dst`.
void common_sampler_clone(common_sampler * src, common_sampler * dst);

// Get the last sampled token
llama_token llama_sampling_last(common_sampler * ctx);

// Get a string representation of the last `n` sampled tokens
std::string llama_sampling_prev_str(common_sampler * ctx_sampling, llama_context * ctx_main, int n);

// Print sampling parameters into a string
std::string llama_sampling_print(const common_params_sampling & params);

// Print sampling order into a string
std::string llama_sampling_order_print(const common_params_sampling & params);

// Name of a sampler type as a string.
std::string llama_sampling_type_to_str(llama_sampler_type sampler_type);

// Parse sampler types from a list of names; allow_alt_names presumably also
// accepts alternative spellings — confirm in the implementation.
std::vector<llama_sampler_type> llama_sampling_types_from_names(const std::vector<std::string> & names, bool allow_alt_names);
// Parse sampler types from a string of one-character codes (presumably the
// char values of llama_sampler_type — confirm in the implementation).
std::vector<llama_sampler_type> llama_sampling_types_from_chars(const std::string & names_string);

// this is a common sampling function used across the examples for convenience
// it can serve as a starting point for implementing your own sampling function
// Note: When using multiple sequences, it is the caller's responsibility to call
//       common_sampler_reset when a sequence ends
//
// required:
//  - ctx_main:     context to use for sampling
//  - ctx_sampling: sampling-specific context
//
// optional:
//  - ctx_cfg:      context to use for classifier-free guidance
//  - idx:          sample from llama_get_logits_ith(ctx, idx)
//
// returns:
//  - the sampled token only. The candidate list is not returned; the sampler's
//    internal candidates can be inspected with common_sampler_get_candidates().
//
llama_token common_sampler_sample_legacy(
        struct common_sampler * ctx_sampling,
        struct llama_context * ctx_main,
        struct llama_context * ctx_cfg,
        int idx = -1);

// Sample one token without a classifier-free-guidance context.
// NOTE(review): grammar_first presumably applies the grammar before the rest of
// the sampler chain — confirm in the implementation.
llama_token common_sampler_sample(
    struct common_sampler * ctx_sampling,
    struct llama_context * ctx_main,
    int idx = -1,
    bool grammar_first = false);

// Prepares and adjusts the set of token candidates for sampling based on penalties, biases, and sampling parameters.
// Note: idx defaults to 0 here, whereas the sampling functions above default to -1.
// original_logits: optional output buffer (nullptr = not requested).
llama_token_data_array llama_sampling_prepare(
        struct common_sampler * ctx_sampling,
        struct llama_context * ctx_main,
        struct llama_context * ctx_cfg,
        int idx = 0,
        bool apply_grammar = true,
        std::vector<float> * original_logits = nullptr);

// if is_generated is true, the token is accepted by the sampling chain, the reasoning budget sampler, and the grammar sampler
void common_sampler_accept(
        struct common_sampler * ctx_sampling,
        struct llama_context * ctx_main,
        llama_token id,
        bool is_generated);

// access the internal list of current candidate tokens
// (the returned pointer refers to sampler-owned data; do_sort presumably sorts
// the candidates first — confirm in the implementation)
llama_token_data_array * common_sampler_get_candidates(struct common_sampler * ctx_sampling, bool do_sort = false);

// Sample and accept tokens against a speculative `draft`.
// returns at least 1 token, up to draft.size()
// NOTE(review): if one token is sampled per logit index, the upper bound may be
// draft.size() + 1 (i.e. idxs.size()) — confirm against the implementation.
std::vector<llama_token> llama_sampling_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<llama_token> & draft);

// Same as above, with explicit logit indices `idxs` to sample from.
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const std::vector<llama_token> & draft, bool grammar_first = false);

// Greedy argmax sampling for speculative drafting
// out_prob: optional output for the chosen token's probability (nullptr = not requested)
llama_token common_sampler_sample_speculative(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, float * out_prob = nullptr);

// Expiring logit bias: apply the current biases to `logits`.
void common_expiring_logit_bias_apply(struct common_sampler* ctx_sampling, float* logits);

// Expiring logit bias: update the bias state after a token is accepted
// (presumably; confirm in the implementation).
void common_expiring_logit_bias_accept(struct common_sampler* ctx_sampling, struct llama_context * ctx_main);

// Create a grammar from (grammar_kind, grammar_data) for `vocab`.
// NOTE(review): the name suggests an llguidance-backed grammar — confirm.
llama_grammar* llama_sampler_init_llg(const llama_vocab* vocab,
    const char* grammar_kind, const char* grammar_data);