#pragma once
#include "llama-kv-cache.h"
#include <vector>
class llama_kv_cache_dsa : public llama_memory_i {
public:
llama_kv_cache_dsa(
const llama_model & model,
ggml_type type_k,
ggml_type type_v,
bool v_trans,
bool offload,
bool unified,
uint32_t kv_size,
uint32_t n_seq_max,
uint32_t n_pad,
uint32_t n_swa,
llama_swa_type swa_type,
const layer_filter_cb & filter,
const layer_reuse_cb & reuse);
~llama_kv_cache_dsa() = default;
llama_memory_context_ptr init_batch(
llama_batch_allocr & balloc,
uint32_t n_ubatch,
bool embd_all) override;
llama_memory_context_ptr init_full() override;
llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
bool get_can_shift() const override;
void clear(bool data) override;
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
void seq_keep(llama_seq_id seq_id) override;
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const override;
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
llama_kv_cache * get_mla() const;
llama_kv_cache * get_lid() const;
private:
llama_hparams hparams_lid;
const uint32_t n_stream = 1;
std::unique_ptr<llama_kv_cache> kv_mla;
std::unique_ptr<llama_kv_cache> kv_lid;
};
class llama_kv_cache_dsa_context : public llama_memory_context_i {
public:
using slot_info_vec_t = llama_kv_cache::slot_info_vec_t;
llama_kv_cache_dsa_context(llama_memory_status status);
llama_kv_cache_dsa_context(
llama_kv_cache_dsa * kv);
llama_kv_cache_dsa_context(
llama_kv_cache_dsa * kv,
llama_context * lctx,
bool optimize);
llama_kv_cache_dsa_context(
llama_kv_cache_dsa * kv,
slot_info_vec_t sinfos_base,
slot_info_vec_t sinfos_ik,
std::vector<llama_ubatch> ubatches);
virtual ~llama_kv_cache_dsa_context();
bool next() override;
bool apply() override;
llama_memory_status get_status() const override;
const llama_ubatch & get_ubatch() const override;
const llama_kv_cache_context * get_mla() const;
const llama_kv_cache_context * get_lid() const;
private:
size_t i_next = 0;
std::vector<llama_ubatch> ubatches;
const llama_memory_context_ptr ctx_mla;
const llama_memory_context_ptr ctx_lid;
const llama_memory_status status;
};