#pragma once
#include "llama-batch.h"
#include "llama-graph.h"
#include "llama-kv-cells.h"
#include "llama-memory.h"
#include <unordered_map>
#include <vector>
struct llama_cparams;
struct llama_hparams;
struct llama_model;
struct llama_context;
// KV cache that implements the llama_memory_i interface.
//
// Only declarations (plus a few small inline helpers) live in this header; the
// behavior of every out-of-line member is defined in the corresponding source file.
class llama_kv_cache : public llama_memory_i {
public:
    // Two parallel lists of stream indices.
    // NOTE(review): presumably entry i means "copy stream ssrc[i] into stream sdst[i]" - confirm against update().
    struct stream_copy_info {
        // Returns true when no entries are recorded.
        // Debug builds additionally check that both lists have the same length.
        bool empty() const {
            assert(ssrc.size() == sdst.size());
            return ssrc.empty();
        }

        std::vector<uint32_t> ssrc;
        std::vector<uint32_t> sdst;
    };

    // Per-stream lists of cell indices.
    //
    // Invariant maintained by resize()/clear() and asserted by size():
    //   strm.size() == idxs.size()
    struct slot_info {
        using idx_vec_t = std::vector<uint32_t>;

        // NOTE(review): presumably the first/last stream covered by this slot - confirm against find_slot().
        // Zero-initialized so that a default-constructed slot_info never carries indeterminate values.
        uint32_t s0 = 0;
        uint32_t s1 = 0;

        std::vector<llama_seq_id> strm; // one entry per stream
        std::vector<idx_vec_t>    idxs; // one index list per stream

        // First index of the only stream.
        // Aborts (GGML_ASSERT) unless there is exactly one stream and it is non-empty.
        uint32_t head() const {
            GGML_ASSERT(idxs.size() == 1);
            GGML_ASSERT(!idxs[0].empty());

            return idxs[0][0];
        }

        // Resize both per-stream containers to n entries.
        void resize(size_t n) {
            strm.resize(n);
            idxs.resize(n);
        }

        // Number of indices stored for the first stream.
        // Aborts (GGML_ASSERT) if the per-stream containers disagree in length or there are no streams.
        size_t size() const {
            GGML_ASSERT(idxs.size() == strm.size());
            GGML_ASSERT(!idxs.empty());

            return idxs[0].size();
        }

        // Number of streams described by this slot.
        size_t n_stream() const {
            return strm.size();
        }

        // True when no index lists are stored.
        bool empty() const {
            return idxs.empty();
        }

        // Drop all per-stream data.
        // Both containers are cleared together: clearing only idxs would leave n_stream()
        // reporting a stale count and break the strm/idxs size invariant checked by size().
        void clear() {
            strm.clear();
            idxs.clear();
        }

        // True when the indices form the run h, h+1, h+2, ... starting at the first index.
        //  - no streams, or an empty first stream -> contiguous
        //  - more than one stream                 -> not contiguous
        bool is_contiguous() const {
            if (idxs.empty() || idxs[0].empty()) {
                return true;
            }
            if (idxs.size() > 1) {
                return false;
            }

            const idx_vec_t & first = idxs[0];
            const uint32_t    h     = first[0];

            // h + i is evaluated in size_t, so the comparison cannot wrap around in 32 bits
            for (size_t i = 0; i < first.size(); ++i) {
                if (first[i] != h + i) {
                    return false;
                }
            }
            return true;
        }
    };

    using slot_info_vec_t = std::vector<slot_info>;

    llama_kv_cache(
            const llama_model     & model,
            ggml_type               type_k,
            ggml_type               type_v,
            bool                    v_trans,
            bool                    offload,
            bool                    unified,
            uint32_t                kv_size,
            uint32_t                n_seq_max,
            uint32_t                n_pad,
            uint32_t                n_swa,
            llama_swa_type          swa_type,
            const layer_filter_cb & filter,
            const layer_reuse_cb  & reuse);

    ~llama_kv_cache() = default;

    //
    // llama_memory_i overrides
    //

    llama_memory_context_ptr init_batch(
            llama_batch_allocr & balloc,
            uint32_t             n_ubatch,
            bool                 embd_all) override;

    llama_memory_context_ptr init_full() override;

    llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;

    bool get_can_shift() const override;

    void clear(bool data) override;

    bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
    void seq_keep(llama_seq_id seq_id) override;
    void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos shift) override;
    void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override;

    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
    llama_pos seq_pos_max(llama_seq_id seq_id) const override;

    std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const override;

    // state write/read; seq_id defaults to -1 and flags to 0
    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;

    //
    // llama_kv_cache specific API
    //

    uint32_t get_size()     const;
    uint32_t get_n_stream() const;

    bool get_has_shift() const;

    ggml_type type_k() const;
    ggml_type type_v() const;

    //
    // graph_build API
    //

    uint32_t get_n_kv(const slot_info & sinfo) const;

    // il is a layer index
    // NOTE(review): presumably the model layer index, translated through map_layer_ids - confirm in the implementation
    ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
    ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;

    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const;
    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const;

    //
    // preparation API
    //

    // returns one slot_info per ubatch
    // NOTE(review): presumably an empty vector signals failure - confirm in the implementation
    slot_info_vec_t prepare(const std::vector<llama_ubatch> & ubatches);

    bool update(llama_context * lctx, bool do_shift, const stream_copy_info & sc_info);

    // const: looking for a slot does not modify the cache object
    // NOTE(review): presumably `cont` requests a contiguous slot and an empty result means "not found" - confirm
    slot_info find_slot(const llama_ubatch & ubatch, bool cont) const;

    void apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch);

    //
    // input API
    //

    ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
    ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;

    ggml_tensor * build_input_k_rot(ggml_context * ctx) const;
    ggml_tensor * build_input_v_rot(ggml_context * ctx) const;

    void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
    void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;

    void set_input_k_shift(ggml_tensor * dst) const;

    void set_input_kq_mask   (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
    void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;

    void set_input_k_rot(ggml_tensor * dst) const;
    void set_input_v_rot(ggml_tensor * dst) const;

private:
    // NOTE: the declaration order of the data members below is the initialization order - do not reorder

    const llama_model   & model;
    const llama_hparams & hparams;

    // K/V tensors of a single cache layer
    struct kv_layer {
        // layer index
        // NOTE(review): presumably the index in the model rather than in `layers` - confirm
        uint32_t il = 0;

        ggml_tensor * k = nullptr;
        ggml_tensor * v = nullptr;

        // NOTE(review): presumably per-stream views of k/v - confirm in the constructor
        std::vector<ggml_tensor *> k_stream;
        std::vector<ggml_tensor *> v_stream;
    };

    bool v_trans = true;

    const uint32_t n_seq_max = 1;
    const uint32_t n_stream  = 1;

    const uint32_t n_pad = 1;

    const uint32_t n_swa = 0;

    bool attn_rot_k = false;
    bool attn_rot_v = false;

    int32_t n_embd_head_k_all = 0;
    int32_t n_embd_head_v_all = 0;

    // NOTE(review): presumably cached rotation data keyed by size - confirm in the implementation
    std::unordered_map<int64_t, std::vector<float>> attn_rot_hadamard;

    // debug level; 0 by default
    int debug = 0;

    const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;

    // ggml contexts paired with backend buffers
    std::vector<std::pair<ggml_context_ptr, ggml_backend_buffer_ptr>> ctxs_bufs;

    // NOTE(review): v_heads / v_cells / seq_to_stream are presumably indexed per stream / per sequence - confirm
    std::vector<uint32_t> v_heads;

    std::vector<llama_kv_cells> v_cells;

    std::vector<uint32_t> seq_to_stream;

    stream_copy_info sc_info;

    std::vector<kv_layer> layers;

    // NOTE(review): presumably model layer id -> index into `layers` - confirm
    std::unordered_map<int32_t, int32_t> map_layer_ids;

    size_t total_size() const;

    size_t size_k_bytes() const;
    size_t size_v_bytes() const;

    ggml_tensor * build_rope_shift(
            const llama_cparams & cparams,
            ggml_context        * ctx,
            ggml_tensor         * cur,
            ggml_tensor         * shift,
            ggml_tensor         * rot,
            ggml_tensor         * factors,
            float                 freq_base,
            float                 freq_scale,
            uint32_t              il) const;

    ggml_cgraph * build_graph_shift(
            llm_graph_result * res,
            llama_context    * lctx) const;

    // A stream id plus a list of (uint32_t, uint32_t) pairs.
    // NOTE(review): presumably each pair is a cell range of that stream - confirm the bounds convention in state_write()
    struct cell_ranges_t {
        uint32_t strm = 0;

        std::vector<std::pair<uint32_t, uint32_t>> data;
    };

    void state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id = -1) const;
    void state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const;

    // sinfo is an output of state_read_meta (non-const reference) and an input of state_read_data
    bool state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count,       slot_info & sinfo, llama_seq_id dest_seq_id = -1);
    bool state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, const slot_info & sinfo);
};
// Memory context (llama_memory_context_i) that operates on a llama_kv_cache.
//
// Most accessors mirror methods of llama_kv_cache without the slot_info argument.
// NOTE(review): presumably they forward to `kv` using the slot info of the current ubatch - confirm in the implementation
class llama_kv_cache_context : public llama_memory_context_i {
public:
    // shorthands for the nested llama_kv_cache types
    using slot_info_vec_t  = llama_kv_cache::slot_info_vec_t;
    using stream_copy_info = llama_kv_cache::stream_copy_info;

    // status-only context: no cache or llama_context is supplied
    // NOTE(review): presumably used to report errors - confirm at the call sites
    llama_kv_cache_context(llama_memory_status status);

    // context bound to a cache only
    llama_kv_cache_context(
            llama_kv_cache * kv);

    // context carrying update parameters (shift flag and stream copies)
    llama_kv_cache_context(
            llama_kv_cache * kv,
            llama_context  * lctx,
            bool             do_shift,
            stream_copy_info sc_info);

    // context carrying a list of ubatches together with their slot infos
    llama_kv_cache_context(
            llama_kv_cache *          kv,
            slot_info_vec_t           sinfos,
            std::vector<llama_ubatch> ubatches);

    virtual ~llama_kv_cache_context();

    //
    // llama_memory_context_i overrides
    //

    bool next()  override;
    bool apply() override;

    llama_memory_status  get_status() const override;
    const llama_ubatch & get_ubatch() const override;

    //
    // llama_kv_cache_context specific API
    //

    uint32_t get_n_kv() const;

    ggml_type type_k() const;
    ggml_type type_v() const;

    ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
    ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;

    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const;
    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const;

    ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
    ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;

    ggml_tensor * build_input_k_rot(ggml_context * ctx) const;
    ggml_tensor * build_input_v_rot(ggml_context * ctx) const;

    void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;
    void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;

    void set_input_k_shift   (ggml_tensor * dst) const;
    void set_input_kq_mask   (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
    void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;

    void set_input_k_rot(ggml_tensor * dst) const;
    void set_input_v_rot(ggml_tensor * dst) const;

private:
    // NOTE: the declaration order of the data members below is the initialization order - do not reorder

    llama_memory_status status;

    // Non-owning raw pointers.
    // Defaulted to nullptr so that constructors which do not receive them (e.g. the status-only one)
    // cannot leave them indeterminate; a constructor mem-initializer still takes precedence.
    llama_kv_cache * kv   = nullptr;
    llama_context  * lctx = nullptr;

    //
    // parameters supplied by the (kv, lctx, do_shift, sc_info) constructor
    //

    bool do_shift = false;

    stream_copy_info sc_info;

    //
    // parameters supplied by the (kv, sinfos, ubatches) constructor
    //

    // NOTE(review): presumably the index of the ubatch currently being processed, advanced by next() - confirm
    size_t i_cur = 0;

    slot_info_vec_t sinfos;

    std::vector<llama_ubatch> ubatches;

    // value reported through get_n_kv() - NOTE(review): confirm where it is computed
    // defaulted to 0 for the same reason as the pointers above
    int32_t n_kv = 0;
};