llama-cpp-sys-4 0.3.1

Low Level Bindings to llama.cpp
Documentation
#include "llama-kv-cache-dsa.h"

#include "llama-impl.h"
#include "llama-batch.h"
#include "llama-model.h"

#include <algorithm>
#include <cassert>

//
// llama_kv_cache_dsa
//

llama_kv_cache_dsa::llama_kv_cache_dsa(
        const llama_model & model,
                ggml_type   type_k,
                ggml_type   type_v,
                     bool   v_trans,
                     bool   offload,
                     bool   unified,
                 uint32_t   kv_size,
                 uint32_t   n_seq_max,
                 uint32_t   n_pad,
                 uint32_t   n_swa,
           llama_swa_type   swa_type,
    const layer_filter_cb & filter,
    const  layer_reuse_cb & reuse) :
    hparams_lid(model.hparams), n_stream(unified ? 1 : n_seq_max) {

    LLAMA_LOG_INFO("%s: creating main KV cache, size = %u cells\n", __func__, kv_size);

    kv_mla = std::make_unique<llama_kv_cache>(
            model, model.hparams, type_k, type_v,
            v_trans, offload, unified, kv_size, n_seq_max, n_pad,
            n_swa, swa_type, filter, reuse);

    // we use llama_kv_cache for caching indexer keys
    // by hand-tweaking some hparams we fool it to create
    // indexer key cache tensors with correct dimensions
    // https://github.com/ggml-org/llama.cpp/pull/21149#discussion_r3015940823

    // DSA lightning indexer uses MQA with single key head
    std::fill(hparams_lid.n_head_kv_arr.begin(), hparams_lid.n_head_kv_arr.end(), 1);
    hparams_lid.n_embd_head_k_full = model.hparams.indexer_head_size;
    hparams_lid.rope_type          = LLAMA_ROPE_TYPE_NEOX;

    LLAMA_LOG_INFO("%s: creating indexer KV cache, size = %u cells\n", __func__, kv_size);

    kv_lid = std::make_unique<llama_kv_cache>(
            model, hparams_lid, type_k, type_v,
            v_trans, offload, unified, kv_size, n_seq_max, n_pad,
            n_swa, swa_type, filter, reuse);
}

void llama_kv_cache_dsa::clear(bool data) {
    kv_mla->clear(data);
    kv_lid->clear(data);
}

bool llama_kv_cache_dsa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
    bool res = true;

    res = res & kv_mla->seq_rm(seq_id, p0, p1);
    res = res & kv_lid->seq_rm(seq_id, p0, p1);

    return res;
}

void llama_kv_cache_dsa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
    kv_mla->seq_cp(seq_id_src, seq_id_dst, p0, p1);
    kv_lid->seq_cp(seq_id_src, seq_id_dst, p0, p1);
}

void llama_kv_cache_dsa::seq_keep(llama_seq_id seq_id) {
    kv_mla->seq_keep(seq_id);
    kv_lid->seq_keep(seq_id);
}

void llama_kv_cache_dsa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
    kv_mla->seq_add(seq_id, p0, p1, shift);
    kv_lid->seq_add(seq_id, p0, p1, shift);
}

void llama_kv_cache_dsa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
    kv_mla->seq_div(seq_id, p0, p1, d);
    kv_lid->seq_div(seq_id, p0, p1, d);
}

llama_pos llama_kv_cache_dsa::seq_pos_min(llama_seq_id seq_id) const {
    return kv_mla->seq_pos_min(seq_id);
}

llama_pos llama_kv_cache_dsa::seq_pos_max(llama_seq_id seq_id) const {
    return kv_mla->seq_pos_max(seq_id);
}

std::map<ggml_backend_buffer_type_t, size_t> llama_kv_cache_dsa::memory_breakdown() const {
    std::map<ggml_backend_buffer_type_t, size_t> mb = kv_mla->memory_breakdown();
    for (const auto & buft_size : kv_lid->memory_breakdown()) {
        mb[buft_size.first] += buft_size.second;
    }
    return mb;
}

llama_memory_context_ptr llama_kv_cache_dsa::init_batch(
            llama_batch_allocr & balloc,
            uint32_t n_ubatch,
            bool embd_all) {
    GGML_UNUSED(embd_all);

    do {
        balloc.split_reset();

        std::vector<llama_ubatch> ubatches;
        while (true) {
            auto ubatch = n_stream == 1 ? balloc.split_simple(n_ubatch) : balloc.split_equal(n_ubatch, true);

            if (ubatch.n_tokens == 0) {
                break;
            }

            ubatches.push_back(std::move(ubatch)); // NOLINT
        }

        if (balloc.get_n_used() < balloc.get_n_tokens()) {
            // failed to find a suitable split
            break;
        }

        auto sinfos_mla = kv_mla->prepare(ubatches);
        if (sinfos_mla.empty()) {
            break;
        }

        auto sinfos_lid = kv_lid->prepare(ubatches);
        if (sinfos_lid.empty()) {
            break;
        }

        assert(sinfos_mla.size() == sinfos_lid.size());

        return std::make_unique<llama_kv_cache_dsa_context>(
                this, std::move(sinfos_mla), std::move(sinfos_lid), std::move(ubatches));
    } while (false);

    return std::make_unique<llama_kv_cache_dsa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
}

llama_memory_context_ptr llama_kv_cache_dsa::init_full() {
    return std::make_unique<llama_kv_cache_dsa_context>(this);
}

llama_memory_context_ptr llama_kv_cache_dsa::init_update(llama_context * lctx, bool optimize) {
    return std::make_unique<llama_kv_cache_dsa_context>(this, lctx, optimize);
}

bool llama_kv_cache_dsa::get_can_shift() const {
    return kv_mla->get_can_shift() &&
           kv_lid->get_can_shift() &&
           kv_mla->get_size() == kv_lid->get_size();
}

void llama_kv_cache_dsa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
    kv_mla->state_write(io, seq_id, flags);
    kv_lid->state_write(io, seq_id, flags);
}

void llama_kv_cache_dsa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
    kv_mla->state_read(io, seq_id, flags);
    kv_lid->state_read(io, seq_id, flags);
}

llama_kv_cache * llama_kv_cache_dsa::get_mla() const {
    return kv_mla.get();
}

llama_kv_cache * llama_kv_cache_dsa::get_lid() const {
    return kv_lid.get();
}

//
// llama_kv_cache_dsa_context
//

llama_kv_cache_dsa_context::llama_kv_cache_dsa_context(llama_memory_status status) : status(status) {}

llama_kv_cache_dsa_context::llama_kv_cache_dsa_context(
        llama_kv_cache_dsa * kv) :
    ctx_mla(kv->get_mla()->init_full()),
    ctx_lid(kv->get_lid()->init_full()),
    status(llama_memory_status_combine(ctx_mla->get_status(), ctx_lid->get_status())) {
}

llama_kv_cache_dsa_context::llama_kv_cache_dsa_context(
        llama_kv_cache_dsa * kv,
        llama_context * lctx,
        bool optimize) :
    ctx_mla(kv->get_mla()->init_update(lctx, optimize)),
    ctx_lid(kv->get_lid()->init_update(lctx, optimize)),
    status(llama_memory_status_combine(ctx_mla->get_status(), ctx_lid->get_status())) {
}

llama_kv_cache_dsa_context::llama_kv_cache_dsa_context(
        llama_kv_cache_dsa * kv,
        slot_info_vec_t sinfos_mla,
        slot_info_vec_t sinfos_lid,
        std::vector<llama_ubatch> ubatches) :
    ubatches(std::move(ubatches)),
    // note: here we copy the ubatches. not sure if this is ideal
    ctx_mla(new llama_kv_cache_context(kv->get_mla(), std::move(sinfos_mla), this->ubatches)),
    ctx_lid(new llama_kv_cache_context(kv->get_lid(), std::move(sinfos_lid), this->ubatches)),
    status(llama_memory_status_combine(ctx_mla->get_status(), ctx_lid->get_status())) {
}

llama_kv_cache_dsa_context:: ~llama_kv_cache_dsa_context() = default;

bool llama_kv_cache_dsa_context::next() {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

    ctx_mla->next();
    ctx_lid->next();

    if (++i_next >= ubatches.size()) {
        return false;
    }

    return true;
}

bool llama_kv_cache_dsa_context::apply() {
    assert(!llama_memory_status_is_fail(status));

    bool res = true;

    res = res & ctx_mla->apply();
    res = res & ctx_lid->apply();

    return res;
}

llama_memory_status llama_kv_cache_dsa_context::get_status() const {
    return status;
}

const llama_ubatch & llama_kv_cache_dsa_context::get_ubatch() const {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

    return ubatches[i_next];
}

const llama_kv_cache_context * llama_kv_cache_dsa_context::get_mla() const {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

    return static_cast<const llama_kv_cache_context *>(ctx_mla.get());
}

const llama_kv_cache_context * llama_kv_cache_dsa_context::get_lid()  const {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

    return static_cast<const llama_kv_cache_context *>(ctx_lid.get());
}