use std::collections::VecDeque;
use std::num::NonZeroU32;
use crate::slot::SlotEntry;
use crate::types::{CheckpointParams, KvCacheParams};
/// Saved partial state of sequence 0 of the persistent context.
///
/// Produced by `maybe_create_checkpoint` (serialised with
/// `LlamaStateSeqFlags::PARTIAL_ONLY`) and consumed by `restore_or_clear`.
struct Checkpoint {
    /// Lowest position covered by the checkpoint. Always written as `0` where
    /// checkpoints are created and never read in the code visible here, which is
    /// why the `dead_code` lint is silenced.
    /// NOTE(review): presumably kept for parity with llama.cpp's checkpoint
    /// layout (pos_min/pos_max) — confirm before removing.
    #[allow(dead_code)]
    pos_min: i32,
    /// Last position covered: `n_tokens - 1`, as computed at creation time.
    pos_max: i32,
    /// Number of tokens that had been decoded when the checkpoint was taken.
    n_tokens: usize,
    /// Serialised state bytes, truncated to the length actually written.
    data: Vec<u8>,
}
/// A llama context that is kept alive between uses (created lazily by
/// `ensure_persistent_ctx`), plus the bookkeeping needed to reuse its KV cache.
pub(crate) struct PersistentCtx<'m> {
    /// The underlying llama context; sequence 0 is the sequence that is trimmed,
    /// restored and checkpointed in this module.
    pub(crate) ctx: llama_cpp_2::context::LlamaContext<'m>,
    /// Entries describing what is currently held in the context. Truncated to a
    /// checkpoint's `n_tokens` on restore and cleared on a full reset.
    /// NOTE(review): presumably one entry per decoded position — confirm
    /// against `SlotEntry`.
    pub(crate) last_entries: Vec<SlotEntry>,
    /// Initialised to `false` when the context is created; not otherwise used in
    /// the code visible here.
    /// NOTE(review): presumably records that KV trimming is unsupported — confirm.
    pub(crate) trim_unsupported: bool,
    /// Checkpoints ordered oldest to newest: new ones are pushed to the back and
    /// eviction pops from the front.
    checkpoints: VecDeque<Checkpoint>,
}
impl PersistentCtx<'_> {
    /// Keeps only the checkpoints whose last position lies strictly below `upper`.
    ///
    /// `pos_max` is compared after an `as usize` cast, exactly as before.
    pub(crate) fn retain_checkpoints_below(&mut self, upper: usize) {
        self.checkpoints
            .retain(|checkpoint| upper > checkpoint.pos_max as usize);
    }

    /// Forgets every stored checkpoint.
    #[cfg(feature = "mtmd")]
    pub(crate) fn clear_checkpoints(&mut self) {
        self.checkpoints.clear();
    }

    /// How many checkpoints are currently stored.
    pub(crate) fn checkpoint_count(&self) -> usize {
        let stored = &self.checkpoints;
        stored.len()
    }
}
/// Returns the persistent llama context, creating it on first use.
///
/// When `persistent` is `None`, a context is built from `model` on `backend` with
/// `n_ctx` (passed through `NonZeroU32::new`, so `0` becomes `None`) and the K/V
/// cache types from `kv_cache`. It starts with no entries, no checkpoints and
/// `trim_unsupported == false`. When a context already exists it is returned
/// untouched and the remaining arguments are ignored.
///
/// # Errors
/// Returns `"Context creation failed: …"` when `new_context` fails; `persistent`
/// stays `None` in that case.
pub(crate) fn ensure_persistent_ctx<'a, 'm>(
    backend: &'m llama_cpp_2::llama_backend::LlamaBackend,
    model: &'m llama_cpp_2::model::LlamaModel,
    n_ctx: u32,
    kv_cache: &KvCacheParams,
    persistent: &'a mut Option<PersistentCtx<'m>>,
) -> Result<&'a mut PersistentCtx<'m>, String> {
    use llama_cpp_2::context::params::LlamaContextParams;

    if persistent.is_none() {
        let params = LlamaContextParams::default()
            .with_n_ctx(NonZeroU32::new(n_ctx))
            .with_type_k(kv_cache.type_k.into())
            .with_type_v(kv_cache.type_v.into());
        let fresh_ctx = match model.new_context(backend, params) {
            Ok(created) => created,
            Err(e) => return Err(format!("Context creation failed: {e}")),
        };
        return Ok(persistent.insert(PersistentCtx {
            ctx: fresh_ctx,
            last_entries: Vec::new(),
            trim_unsupported: false,
            checkpoints: VecDeque::new(),
        }));
    }

    Ok(persistent
        .as_mut()
        .expect("persistent context was just initialised above"))
}
/// Rewinds the persistent context to the newest checkpoint that fits inside the
/// `cached` shared prefix, or resets the context completely.
///
/// A checkpoint is a candidate when it covers at most `cached` tokens and its
/// `pos_max` is a non-negative position below `cached`. The newest candidate's
/// partial state is restored into sequence 0 (`PARTIAL_ONLY`). KV entries from
/// position `n_tokens` onwards are removed, and `last_entries` / `checkpoints` are
/// truncated so nothing newer than that checkpoint survives.
///
/// If there is no candidate, the state restore fails, or removing the stale KV
/// tail fails, the whole KV cache is cleared together with `last_entries` and
/// all checkpoints.
///
/// Returns the number of leading tokens still valid in the context (`0` after a
/// full reset).
pub(crate) fn restore_or_clear(p: &mut PersistentCtx<'_>, cached: usize) -> usize {
    use llama_cpp_2::context::session::LlamaStateSeqFlags;

    // `pos_max` is an `i32`: convert fallibly so a negative value can never wrap
    // into a huge `usize` inside the comparison.
    let candidate_idx = p.checkpoints.iter().rposition(|c| {
        c.n_tokens <= cached && usize::try_from(c.pos_max).is_ok_and(|pos| pos < cached)
    });
    if let Some(idx) = candidate_idx {
        let n_tokens = p.checkpoints[idx].n_tokens;
        // SAFETY: the buffer is a checkpoint captured from this same context with the
        // same `PARTIAL_ONLY` flags and is borrowed immutably for the whole call.
        // NOTE(review): confirm against the wrapper's documented safety contract.
        let restored = unsafe {
            p.ctx.state_seq_set_data_ext(
                &p.checkpoints[idx].data,
                0,
                LlamaStateSeqFlags::PARTIAL_ONLY,
            )
        };
        if !restored {
            log::warn!("state_seq_set_data_ext failed; clearing cache.");
        } else {
            // KV entries at positions >= n_tokens belong to the previous prompt. If
            // they cannot be removed, reporting `n_tokens` as valid would let decoding
            // resume on top of stale positions, so a failed trim falls back to a full
            // reset instead of being ignored.
            let trimmed = match u32::try_from(n_tokens) {
                Ok(p0) => matches!(p.ctx.clear_kv_cache_seq(Some(0), Some(p0), None), Ok(true)),
                Err(_) => false,
            };
            if trimmed {
                p.last_entries.truncate(n_tokens);
                p.checkpoints.truncate(idx + 1);
                log::debug!("restored checkpoint at n_tokens={n_tokens} (cached LCP was {cached})");
                return n_tokens;
            }
            log::warn!(
                "failed to trim KV cache after restoring checkpoint at n_tokens={n_tokens}; clearing cache."
            );
        }
    }
    p.ctx.clear_kv_cache();
    p.last_entries.clear();
    p.checkpoints.clear();
    0
}
/// Captures a partial-state checkpoint of sequence 0 if one is due.
///
/// Nothing is captured when any of the following holds:
/// - `params.max_checkpoints == 0` (checkpointing disabled);
/// - no tokens have been decoded yet, or fewer than `params.min_tokens`;
/// - the last position `n_tokens_decoded - 1` does not fit in an `i32`;
/// - neither trigger fires. The "near end" trigger fires when `prompt_len` equals
///   `n_tokens_decoded + 4` or `n_tokens_decoded + 4 + n_ubatch`. The cadence
///   trigger fires when at least `params.every_n_tokens` (> 0) tokens have been
///   decoded since the newest checkpoint;
/// - the newest checkpoint is fewer than `params.min_gap` tokens back;
/// - the context reports a zero-sized partial state, or the copy writes nothing.
///
/// When the store is full, the oldest checkpoints are evicted to make room.
pub(crate) fn maybe_create_checkpoint(
    p: &mut PersistentCtx<'_>,
    params: CheckpointParams,
    n_tokens_decoded: usize,
    prompt_len: usize,
) {
    use llama_cpp_2::context::session::LlamaStateSeqFlags;

    if params.max_checkpoints == 0 {
        return;
    }
    // With nothing decoded there is no state worth saving, and the checkpoint's
    // `pos_max` would be -1, which `restore_or_clear` can never select.
    if n_tokens_decoded == 0 || n_tokens_decoded < params.min_tokens as usize {
        return;
    }
    // Positions are `i32`: refuse to record a checkpoint whose last position cannot
    // be represented rather than silently truncating it.
    let Ok(pos_max) = i32::try_from(n_tokens_decoded - 1) else {
        return;
    };

    let n_ubatch = p.ctx.n_ubatch().max(1) as usize;
    let near_end =
        n_tokens_decoded + 4 + n_ubatch == prompt_len || n_tokens_decoded + 4 == prompt_len;
    let last_n_tokens = p.checkpoints.back().map_or(0, |c| c.n_tokens);
    // `every_n_tokens == 0` disables the periodic trigger.
    let cadence_ok = params.every_n_tokens > 0
        && n_tokens_decoded.saturating_sub(last_n_tokens) >= params.every_n_tokens as usize;
    if !(near_end || cadence_ok) {
        return;
    }
    // Keep a minimum distance from the newest checkpoint.
    if p.checkpoints
        .back()
        .is_some_and(|c| n_tokens_decoded.saturating_sub(c.n_tokens) < params.min_gap as usize)
    {
        return;
    }

    let size = p.ctx.state_seq_get_size_ext(0, LlamaStateSeqFlags::PARTIAL_ONLY);
    // NOTE(review): presumably zero when the model has no partial (non-trimmable)
    // state — confirm.
    if size == 0 {
        return;
    }
    let mut data = vec![0u8; size];
    // SAFETY: `data` is a live, exclusively borrowed buffer of exactly `size` bytes,
    // the size the context just reported for sequence 0 with the same flags.
    // NOTE(review): confirm against the wrapper's documented safety contract.
    let written = unsafe {
        p.ctx.state_seq_get_data_ext(data.as_mut_ptr(), 0, LlamaStateSeqFlags::PARTIAL_ONLY)
    };
    if written == 0 {
        return;
    }
    data.truncate(written);

    // Bounded FIFO: evict the oldest checkpoints until there is room for one more.
    while p.checkpoints.len() >= params.max_checkpoints as usize {
        p.checkpoints.pop_front();
    }
    p.checkpoints.push_back(Checkpoint {
        pos_min: 0,
        pos_max,
        n_tokens: n_tokens_decoded,
        data,
    });
    log::debug!(
        "checkpoint created at n_tokens={n_tokens_decoded} (size={} KiB, total={})",
        written / 1024,
        p.checkpoints.len(),
    );
}