// slm_ikllama_sys 0.1.1
//
// ik_llama.cpp rust sys bindings
#include "reasoning-budget.h"
#include "common.h"
#include "unicode.h"

#include "log.h"

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <string>
#include <vector>

// Incremental matcher for a fixed token sequence (e.g. the tokens of a
// reasoning start or end tag). Tokens are fed one at a time via advance();
// `pos` is the length of the longest prefix of `tokens` that is also a suffix
// of the input seen so far.
struct token_matcher {
    std::vector<llama_token> tokens;
    size_t pos = 0;

    // Feeds one token. Returns true exactly when this token completes the
    // sequence, after which the matcher restarts from an empty prefix.
    // An empty sequence never matches.
    bool advance(llama_token token) {
        if (tokens.empty()) {
            return false;
        }

        if (token == tokens[pos]) {
            pos++;
            if (pos >= tokens.size()) {
                pos = 0;
                return true;
            }
            return false;
        }

        // Mismatch: fall back to the longest proper prefix of `tokens` that is
        // still a suffix of the input. The previous `pos` inputs were
        // tokens[0..pos) and the newest one is `token`, so a candidate prefix of
        // length k needs tokens[k-1] == token and
        // tokens[0..k-1) == tokens[pos-k+1..pos).
        // Restarting from 0 (or 1) instead would miss matches for sequences
        // with repeated prefixes, e.g. {A, A, B} against the input A A A B.
        size_t k = pos;
        while (k > 0 &&
               !(tokens[k - 1] == token &&
                 std::equal(tokens.begin(), tokens.begin() + (k - 1), tokens.begin() + (pos - k + 1)))) {
            k--;
        }
        pos = k;
        return false;
    }

    // Forgets any partial match.
    void reset() { pos = 0; }
};

// Per-instance state of the reasoning-budget logic: whether generation is
// inside a reasoning block, how much of the token budget is left, and how far
// through the forced end sequence we are.
//
// Every member has a default initializer so that a value-initialized or
// partially aggregate-initialized instance never holds indeterminate values.
// Member order matters: instances are built with positional aggregate
// initialization.
struct common_reasoning_budget_ctx {
    const llama_vocab * vocab = nullptr;      // used to detokenize for UTF-8 checks; may be null

    token_matcher start_matcher;              // start sequence: arms the budget
    token_matcher end_matcher;                // end sequence: natural end of a block
    std::vector<llama_token> forced_tokens;   // sequence forced once the budget is spent

    int32_t budget    = 0;                    // maximum tokens in reasoning block
    int32_t remaining = 0;                    // tokens remaining in budget

    common_reasoning_budget_state state = REASONING_BUDGET_IDLE;

    // for forcing
    size_t force_pos = 0;                     // next position in forced_tokens to force
};

// Identifier of this sampler, used for display/diagnostics. The context
// argument is not consulted.
static const char * common_reasoning_budget_name(const common_reasoning_budget_ctx * /*smpl*/) {
    static const char * const k_name = "reasoning-budget";
    return k_name;
}

static void common_reasoning_budget_accept(common_reasoning_budget_ctx * smpl, llama_token token) {
    auto * ctx = (common_reasoning_budget_ctx *)smpl;

    switch (ctx->state) {
    case REASONING_BUDGET_IDLE:
    {
        if (ctx->start_matcher.advance(token)) {
            ctx->state = REASONING_BUDGET_COUNTING;
            ctx->remaining = ctx->budget;
            LOG_DBG("reasoning-budget: activated, budget=%d tokens\n", ctx->budget);

            if (ctx->remaining <= 0) {
                ctx->state = REASONING_BUDGET_FORCING;
                ctx->force_pos = 0;
                LOG_DBG("reasoning-budget: budget=0, forcing immediately\n");
            }
        }
        break;
    }
    case REASONING_BUDGET_COUNTING:
    case REASONING_BUDGET_WAITING_UTF8:
    {
        if (ctx->end_matcher.advance(token)) {
            ctx->state = REASONING_BUDGET_DONE;
            LOG_DBG("reasoning-budget: deactivated (natural end)\n");
            break;
        }

        bool utf8_complete = true;
        if (ctx->vocab != nullptr) {
            const std::string piece = common_token_to_piece(ctx->vocab, token, false);
            utf8_complete = common_utf8_is_complete(piece);
        }

        if (ctx->state == REASONING_BUDGET_WAITING_UTF8) {
            if (utf8_complete) {
                ctx->state = REASONING_BUDGET_FORCING;
                ctx->force_pos = 0;
                ctx->end_matcher.reset();
                LOG_DBG("reasoning-budget: UTF-8 complete, now forcing end sequence\n");
            }
        } else if (ctx->state == REASONING_BUDGET_COUNTING) {
            ctx->remaining--;
            if (ctx->remaining <= 0) {
                if (utf8_complete) {
                    ctx->state = REASONING_BUDGET_FORCING;
                    ctx->force_pos = 0;
                    ctx->end_matcher.reset();
                    LOG_DBG("reasoning-budget: budget exhausted, forcing end sequence\n");
                } else {
                    ctx->state = REASONING_BUDGET_WAITING_UTF8;
                    ctx->end_matcher.reset();
                    LOG_DBG("reasoning-budget: budget exhausted, waiting for UTF-8 completion\n");
                }
            }
        }
        break;
    }
    case REASONING_BUDGET_FORCING:
        ctx->force_pos++;
        if (ctx->force_pos >= ctx->forced_tokens.size()) {
            ctx->state = REASONING_BUDGET_DONE;
            LOG_DBG("reasoning-budget: forced sequence complete, done\n");
        }
        break;
    case REASONING_BUDGET_DONE:
        // Re-arm on a new start tag: some models emit multiple <think> blocks
// per response, and each should get a fresh budget window.
        if (ctx->start_matcher.advance(token)) {
            ctx->state = REASONING_BUDGET_COUNTING;
            ctx->remaining = ctx->budget;
            ctx->end_matcher.reset();
            LOG_DBG("reasoning-budget: re-activated on new start tag, budget=%d tokens\n", ctx->budget);

            if (ctx->remaining <= 0) {
                ctx->state = REASONING_BUDGET_FORCING;
                ctx->force_pos = 0;
                LOG_DBG("reasoning-budget: budget=0, forcing immediately\n");
            }
        }
        break;
    }
}

static void common_reasoning_budget_apply(struct common_reasoning_budget_ctx * smpl, llama_token_data_array * cur_p) {
    auto * ctx = (common_reasoning_budget_ctx *)smpl;
    if (!ctx) {
        return;
    }
    if (ctx->state != REASONING_BUDGET_FORCING) {
        // passthrough — don't modify logits
        return;
    }

    if (ctx->force_pos >= ctx->forced_tokens.size()) {
        return;
    }

    const llama_token forced = ctx->forced_tokens[ctx->force_pos];

    // set all logits to -inf except the forced token
    for (size_t i = 0; i < cur_p->size; i++) {
        if (cur_p->data[i].id != forced) {
            cur_p->data[i].logit = -INFINITY;
        }
    }
}

// Returns the context to its idle state: no reasoning block is open, the
// budget counter is refilled, partial start/end matches are forgotten and the
// forced-sequence cursor is rewound. A null context is ignored, consistent
// with apply() and get_state().
static void common_reasoning_budget_reset(common_reasoning_budget_ctx * smpl) {
    if (!smpl) {
        return;
    }
    smpl->state     = REASONING_BUDGET_IDLE;
    smpl->remaining = smpl->budget;
    smpl->start_matcher.reset();
    smpl->end_matcher.reset();
    smpl->force_pos = 0;
}

// Forward declaration of common_reasoning_budget_init_state (defined below).
// NOTE(review): nothing between here and the definition calls it —
// common_reasoning_budget_clone copy-constructs the context directly — so this
// declaration appears to be unnecessary.
static struct common_reasoning_budget_ctx * common_reasoning_budget_init_state(
    const struct llama_vocab * vocab, const std::vector<llama_token> & start_tokens,
    const std::vector<llama_token> & end_tokens, const std::vector<llama_token> & forced_tokens,
    int32_t budget, common_reasoning_budget_state initial_state);

// Returns a new heap-allocated copy of `smpl`: matchers, forced tokens, budget
// counters and current state are copied by value; the vocab pointer is shared.
// The copy should be released with common_reasoning_budget_free.
static struct common_reasoning_budget_ctx * common_reasoning_budget_clone(const struct common_reasoning_budget_ctx * smpl) {
    common_reasoning_budget_ctx * copy = new common_reasoning_budget_ctx(*smpl);
    return copy;
}

// Destroys a context allocated with `new` (by init or clone).
// Passing nullptr is a no-op.
static void common_reasoning_budget_free(struct common_reasoning_budget_ctx * smpl) {
    common_reasoning_budget_ctx * const doomed = smpl;
    delete doomed;
}

//static struct llama_sampler_i common_reasoning_budget_i = {
//    /* .name              = */ common_reasoning_budget_name,
//    /* .accept            = */ common_reasoning_budget_accept,
//    /* .apply             = */ common_reasoning_budget_apply,
//    /* .reset             = */ common_reasoning_budget_reset,
//    /* .clone             = */ common_reasoning_budget_clone,
//    /* .free              = */ common_reasoning_budget_free,
//    /* .backend_init      = */ nullptr,
//    /* .backend_accept    = */ nullptr,
//    /* .backend_apply     = */ nullptr,
//    /* .backend_set_input = */ nullptr,
//};

// Allocates and initializes a reasoning-budget context.
//
//   vocab          optional; enables the UTF-8 boundary check in accept()
//   start_tokens   sequence that opens a reasoning block (arms the budget)
//   end_tokens     sequence that closes a reasoning block naturally
//   forced_tokens  sequence forced once the budget has been spent
//   budget         number of tokens allowed inside one reasoning block
//   initial_state  starting state; COUNTING with a non-positive budget is
//                  promoted to FORCING, since there is nothing to count
//
// Returns a `new`-allocated context; release it with common_reasoning_budget_free.
static common_reasoning_budget_ctx * common_reasoning_budget_init_state(
    const struct llama_vocab * vocab,
    const std::vector<llama_token> & start_tokens,
    const std::vector<llama_token> & end_tokens,
    const std::vector<llama_token> & forced_tokens,
    int32_t                                budget,
    common_reasoning_budget_state          initial_state) {
    const bool no_budget = budget <= 0;
    const common_reasoning_budget_state start_state =
        (initial_state == REASONING_BUDGET_COUNTING && no_budget) ? REASONING_BUDGET_FORCING : initial_state;

    // Positional aggregate initialization: order follows the struct members
    // (vocab, start_matcher, end_matcher, forced_tokens, budget, remaining,
    // state, force_pos).
    common_reasoning_budget_ctx * fresh = new common_reasoning_budget_ctx{
        vocab,
        { start_tokens, 0 },
        { end_tokens, 0 },
        forced_tokens,
        budget,
        budget,
        start_state,
        0,
    };
    return fresh;
}

// Public entry point for creating a reasoning-budget context. It forwards
// every argument unchanged to common_reasoning_budget_init_state; a COUNTING
// initial state with a non-positive budget starts in FORCING instead.
// The returned context is owned by the caller.
struct common_reasoning_budget_ctx *  common_reasoning_budget_init(
    const struct llama_vocab * vocab,
    const std::vector<llama_token> & start_tokens,
    const std::vector<llama_token> & end_tokens,
    const std::vector<llama_token> & forced_tokens,
    int32_t                          budget,
    common_reasoning_budget_state    initial_state) {
    common_reasoning_budget_ctx * created = common_reasoning_budget_init_state(
        vocab, start_tokens, end_tokens, forced_tokens, budget, initial_state);
    return created;
}

// Reports the current state of the context; a null context reads as IDLE.
common_reasoning_budget_state common_reasoning_budget_get_state(const common_reasoning_budget_ctx * smpl) {
    return smpl != nullptr ? smpl->state : REASONING_BUDGET_IDLE;
}