use std::cmp;
use crate::native_bridge::NativeRuntimeHandle;
use crate::runtime::config::KvReuseMode;
use crate::runtime::metrics::CacheSource;
use crate::runtime::scheduler::SlotState;
use crate::runtime::session::{CacheCandidate, CachePreparation, KvCacheManager, SequenceMirror};
use crate::runtime::{llama_seq_id, llama_token};
use super::numeric::{nonnegative_i32_to_usize_opt, usize_to_i32};
/// Decides how many decode tokens to reserve in the context when a request starts.
///
/// Returns `0` when `max_output_tokens` is not positive. Otherwise the result is
/// `decode_reserve` raised to at least `1` and then capped at `max_output_tokens`.
#[inline]
pub(super) fn resolve_initial_decode_context_reservation(
    max_output_tokens: i32,
    decode_reserve: i32,
) -> i32 {
    if max_output_tokens <= 0 {
        return 0;
    }
    // A request that may emit output always reserves room for at least one token.
    let reserve_floor = if decode_reserve < 1 { 1 } else { decode_reserve };
    // Never reserve more than the request is allowed to generate.
    if max_output_tokens < reserve_floor {
        max_output_tokens
    } else {
        reserve_floor
    }
}
/// The prefix-reuse strategies enabled for a request, as produced by `prefix_reuse_plan`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(super) struct PrefixReusePlan {
/// When set, tokens already mirrored for the live sequence may be matched against the prompt.
pub live: bool,
/// When set, restoring a prefix from the snapshot cache may be attempted.
pub snapshot: bool,
}
/// Translates the configured KV reuse mode into the set of reuse strategies allowed for a request.
///
/// `bypass_prefix_cache` overrides `mode`: when it is set, neither strategy is enabled.
pub(super) fn prefix_reuse_plan(mode: KvReuseMode, bypass_prefix_cache: bool) -> PrefixReusePlan {
    // Kept exhaustive (no `_` arm) so a new `KvReuseMode` variant has to be classified here.
    let (mode_allows_live, mode_allows_snapshot) = match mode {
        KvReuseMode::Disabled => (false, false),
        KvReuseMode::LiveSlotPrefix => (true, false),
        KvReuseMode::StateSnapshot => (false, true),
        KvReuseMode::LiveSlotAndSnapshot => (true, true),
    };
    let caching_enabled = !bypass_prefix_cache;
    PrefixReusePlan {
        live: caching_enabled && mode_allows_live,
        snapshot: caching_enabled && mode_allows_snapshot,
    }
}
/// Returns how many leading prompt tokens can be served from the live sequence's cached tokens.
///
/// The result is `0` unless `plan.live` is set and `cache_candidate` is `CacheCandidate::Live`.
///
/// * With `allow_partial_kv`, the result is the length of the common prefix of `cached_tokens`
///   and `prompt_tokens`.
/// * Without it, the cache only counts when it is a strict prefix of the prompt: every cached
///   token matches and the prompt is longer than the cache. In that case the result is
///   `cached_tokens.len()`; otherwise it is `0`.
pub(super) fn live_candidate_lcp(
    plan: PrefixReusePlan,
    cache_candidate: CacheCandidate,
    cached_tokens: &[llama_token],
    prompt_tokens: &[llama_token],
    allow_partial_kv: bool,
) -> usize {
    if !plan.live {
        return 0;
    }
    if cache_candidate != CacheCandidate::Live {
        return 0;
    }

    let shared_prefix = compute_lcp_reuse(cached_tokens, prompt_tokens);
    if allow_partial_kv {
        return shared_prefix;
    }

    // All-or-nothing reuse: the whole cache must match and leave prompt tokens still to process.
    let cache_is_strict_prefix =
        cached_tokens.len() < prompt_tokens.len() && shared_prefix == cached_tokens.len();
    if cache_is_strict_prefix {
        shared_prefix
    } else {
        0
    }
}
/// Returns the common-prefix length of `cached_tokens` and `prompt_tokens`, but only when a
/// cache source has been selected; with `CacheSource::None` nothing may be reused and the
/// result is `0`.
pub(super) fn authorized_lcp(
    source: CacheSource,
    cached_tokens: &[llama_token],
    prompt_tokens: &[llama_token],
) -> usize {
    if source == CacheSource::None {
        return 0;
    }
    compute_lcp_reuse(cached_tokens, prompt_tokens)
}
pub(super) fn ensure_context_space(
native_runtime: &mut NativeRuntimeHandle,
retained_prefix_tokens: i32,
state: &mut SequenceMirror,
seq_id: llama_seq_id,
new_tokens_needed: i32,
) -> bool {
if seq_id < 0 {
return false;
}
if new_tokens_needed <= 0 {
return true;
}
let n_ctx = native_runtime.n_ctx();
if n_ctx <= 0 || new_tokens_needed > n_ctx {
return false;
}
let Some(total_needed) = state.n_past.checked_add(new_tokens_needed) else {
return false;
};
if total_needed <= n_ctx {
return true;
}
let n_keep = retained_prefix_tokens.min(state.n_past).max(0);
let required_discard = total_needed - n_ctx;
let max_discard = (state.n_past - n_keep).max(0);
let n_discard = required_discard.clamp(0, max_discard);
if n_discard <= 0 {
if !native_runtime.clear_sequence(seq_id, 0, -1) {
return false;
}
state.current_kv_tokens.clear();
state.n_past = 0;
return true;
}
let Some(discard_end) = n_keep.checked_add(n_discard) else {
return false;
};
if !native_runtime.clear_sequence(seq_id, n_keep, discard_end) {
return false;
}
native_runtime.add_sequence_delta(seq_id, discard_end, -1, -n_discard);
let Some(n_keep_len) = nonnegative_i32_to_usize_opt(n_keep) else {
return false;
};
let Some(discard_end_len) = nonnegative_i32_to_usize_opt(discard_end) else {
return false;
};
if state.current_kv_tokens.len() > n_keep_len {
let erase_end = cmp::min(discard_end_len, state.current_kv_tokens.len());
state.current_kv_tokens.drain(n_keep_len..erase_end);
} else {
state.current_kv_tokens.clear();
}
let Some(n_past) = usize_to_i32(state.current_kv_tokens.len()) else {
return false;
};
state.n_past = n_past;
let Some(total_needed) = state.n_past.checked_add(new_tokens_needed) else {
return false;
};
if total_needed <= n_ctx {
return true;
}
if !native_runtime.clear_sequence(seq_id, 0, -1) {
return false;
}
state.current_kv_tokens.clear();
state.n_past = 0;
true
}
/// Prepares sequence `seq_id` for evaluating `prompt_tokens`, reusing cached KV state where
/// the reuse mode allows it, and reports where prefill has to start.
///
/// Steps:
/// 1. Match the prompt against the live mirror (`state.current_kv_tokens`) if the plan allows
///    live reuse.
/// 2. If the plan allows snapshots, ask `kv_cache` to restore the best snapshot prefix; a
///    restored snapshot replaces the mirror and becomes the cache source.
/// 3. With no cache source, clear the sequence when `requires_kv_clear` is set or the mirror is
///    non-empty.
/// 4. Reserve context for the prompt tokens still missing plus the initial decode reservation.
/// 5. Re-validate the match against the mirror (step 4 may have discarded tokens) and drop any
///    cached tokens beyond the match.
/// 6. If the whole prompt is cached, give back the last token (or everything, when partial KV
///    removal is not allowed) so that at least one prompt token remains to be evaluated.
///
/// On success `*out_prefill_cursor` is the number of prompt tokens already present in the KV
/// state and the returned [`CachePreparation`] carries the same count as `cache_hits` together
/// with the cache source (`CacheSource::None` when nothing was reused).
///
/// Returns `None` (with `*out_prefill_cursor` left at `0`) when `seq_id` is negative, the
/// prompt is empty, a length does not fit in `i32`, the required token count overflows `i32`,
/// context space cannot be secured, or a partial native clear fails. `state` may already have
/// been modified when `None` is returned.
// One argument per independent input; bundling them would change the callers' interface.
#[allow(clippy::too_many_arguments)]
pub(super) fn prepare_sequence_for_prompt(
    native_runtime: &mut NativeRuntimeHandle,
    retained_prefix_tokens: i32,
    cache_mode: KvReuseMode,
    bypass_prefix_cache: bool,
    decode_token_reserve: i32,
    model_fingerprint: u64,
    kv_cache: &mut KvCacheManager,
    cache_candidate: CacheCandidate,
    requires_kv_clear: bool,
    context_key: &str,
    prompt_tokens: &[llama_token],
    n_tokens_predict: i32,
    state: &mut SequenceMirror,
    seq_id: llama_seq_id,
    out_prefill_cursor: &mut usize,
) -> Option<CachePreparation> {
    *out_prefill_cursor = 0;
    if seq_id < 0 || prompt_tokens.is_empty() {
        return None;
    }

    // Recurrent and hybrid runtimes only get all-or-nothing reuse in this function.
    let allow_partial_kv = !(native_runtime.is_recurrent() || native_runtime.is_hybrid());
    let reuse_plan = prefix_reuse_plan(cache_mode, bypass_prefix_cache);

    let live_match_len = live_candidate_lcp(
        reuse_plan,
        cache_candidate,
        &state.current_kv_tokens,
        prompt_tokens,
        allow_partial_kv,
    );
    let mut match_len = live_match_len;
    let mut cache_source = if live_match_len > 0 {
        CacheSource::Live
    } else {
        CacheSource::None
    };

    if reuse_plan.snapshot {
        // `live_match_len` is handed over so the cache manager can weigh it against snapshots.
        if let Some(snapshot) = kv_cache.restore_best_snapshot_prefix(
            native_runtime,
            seq_id,
            model_fingerprint,
            context_key,
            prompt_tokens,
            live_match_len,
        ) {
            state.current_kv_tokens = snapshot.prefix_tokens;
            state.n_past = usize_to_i32(snapshot.token_count)?;
            match_len = snapshot.token_count.min(prompt_tokens.len());
            cache_source = CacheSource::Snapshot;
        }
    }

    if cache_source == CacheSource::None {
        if requires_kv_clear || !state.current_kv_tokens.is_empty() {
            clear_sequence_state(native_runtime, seq_id, state);
        }
        match_len = 0;
    }

    match_len = match_len.max(authorized_lcp(
        cache_source,
        &state.current_kv_tokens,
        prompt_tokens,
    ));

    let missing_prompt_tokens = prompt_tokens.len().checked_sub(match_len)?;
    let tokens_to_add = usize_to_i32(missing_prompt_tokens)?;
    let decode_reservation =
        resolve_initial_decode_context_reservation(n_tokens_predict, decode_token_reserve);
    // Checked on purpose: a wrapped (negative) sum would make `ensure_context_space` report
    // success without reserving anything. A sum beyond `i32::MAX` can never fit anyway.
    let total_needed = tokens_to_add.checked_add(decode_reservation)?;
    if !ensure_context_space(
        native_runtime,
        retained_prefix_tokens,
        state,
        seq_id,
        total_needed,
    ) {
        return None;
    }

    // Making room may have discarded cached tokens, so the match must be re-validated.
    match_len = match_len.min(authorized_lcp(
        cache_source,
        &state.current_kv_tokens,
        prompt_tokens,
    ));

    // Drop cached tokens that lie beyond the matched prefix.
    if match_len < state.current_kv_tokens.len() {
        if !allow_partial_kv {
            // NOTE(review): the result of this full clear is not checked, unlike the partial
            // clear in the other branch — confirm a full clear cannot fail.
            native_runtime.clear_sequence(seq_id, 0, -1);
            state.current_kv_tokens.clear();
            state.n_past = 0;
            match_len = 0;
            cache_source = CacheSource::None;
        } else {
            let match_len_i32 = usize_to_i32(match_len)?;
            if !native_runtime.clear_sequence(seq_id, match_len_i32, -1) {
                return None;
            }
            state.current_kv_tokens.truncate(match_len);
            state.n_past = match_len_i32;
        }
    }

    // A fully cached prompt would leave nothing to prefill; give back the last token
    // (presumably so the runtime still evaluates it and yields output for it — TODO confirm).
    if match_len == prompt_tokens.len() && match_len > 0 {
        if allow_partial_kv {
            let match_len_i32 = usize_to_i32(match_len)?;
            let last_token_position = match_len_i32.checked_sub(1)?;
            if !native_runtime.clear_sequence(seq_id, last_token_position, -1) {
                return None;
            }
            state.current_kv_tokens.truncate(match_len - 1);
            state.n_past = last_token_position;
            match_len -= 1;
        } else {
            // NOTE(review): result of the full clear is not checked here either.
            native_runtime.clear_sequence(seq_id, 0, -1);
            state.current_kv_tokens.clear();
            state.n_past = 0;
            match_len = 0;
        }
    }

    let cache_hits = usize_to_i32(match_len)?;
    *out_prefill_cursor = match_len;
    if cache_hits == 0 {
        // Nothing was reused in the end, whatever source had been selected earlier.
        cache_source = CacheSource::None;
    }
    Some(CachePreparation {
        source: cache_source,
        cache_hits,
    })
}
/// Clears the whole native sequence `seq_id` and resets the mirror in `state` to empty.
///
/// The outcome of the native clear is not inspected; the mirror is reset either way.
fn clear_sequence_state(
    native_runtime: &mut NativeRuntimeHandle,
    seq_id: llama_seq_id,
    state: &mut SequenceMirror,
) {
    let _ = native_runtime.clear_sequence(seq_id, 0, -1);
    state.n_past = 0;
    state.current_kv_tokens.clear();
}
/// Makes sure the slot's sequence has room for one more decoded token.
///
/// * Returns `false` when the slot has no request.
/// * Returns `true` without touching anything when no token has been generated yet.
/// * For a multimodal turn, returns `false` as soon as `n_past + 1` would exceed the context
///   size (or overflow), i.e. before any discarding is attempted.
/// * Otherwise defers to `ensure_context_space` for a single token.
pub(super) fn ensure_decode_step_context_space(
    native_runtime: &mut NativeRuntimeHandle,
    retained_prefix_tokens: i32,
    slot: &mut SlotState,
) -> bool {
    let Some(is_multimodal_turn) = slot.request().map(|request| request.is_multimodal_turn) else {
        return false;
    };
    if slot.generated_tokens.is_empty() {
        return true;
    }

    if is_multimodal_turn {
        let next_token_fits = slot
            .mirror
            .n_past
            .checked_add(1)
            .is_some_and(|needed| needed <= native_runtime.n_ctx());
        if !next_token_fits {
            return false;
        }
    }

    ensure_context_space(
        native_runtime,
        retained_prefix_tokens,
        &mut slot.mirror,
        slot.seq_id,
        1,
    )
}
/// Length of the longest common prefix of `cached_tokens` and `incoming_tokens`.
fn compute_lcp_reuse(cached_tokens: &[llama_token], incoming_tokens: &[llama_token]) -> usize {
    // If no pair differs, the shorter slice is entirely a prefix of the longer one.
    let overlap = cached_tokens.len().min(incoming_tokens.len());
    cached_tokens
        .iter()
        .zip(incoming_tokens)
        .position(|(cached, incoming)| cached != incoming)
        .unwrap_or(overlap)
}