// sipp-rs 0.1.0
//
// Unified Rust library for extensible Sipp inference
//! KV-cache space management and prompt preparation for prefill.
//!
//! Three primitives:
//! - `ensure_context_space` slides the KV window when the sequence would
//!   exceed `n_ctx`, preserving a retained prefix (or clearing the whole
//!   sequence when sliding alone cannot make enough room).
//! - `prepare_sequence_for_prompt` runs LCP reuse + optional snapshot
//!   restore, writes the prefill cursor, and returns the cache source and
//!   the number of cache-hit tokens.
//! - `ensure_decode_step_context_space` is the per-decode-step variant.

use std::cmp;

use crate::native_bridge::NativeRuntimeHandle;
use crate::runtime::config::KvReuseMode;
use crate::runtime::metrics::CacheSource;
use crate::runtime::scheduler::SlotState;
use crate::runtime::session::{CacheCandidate, CachePreparation, KvCacheManager, SequenceMirror};
use crate::runtime::{llama_seq_id, llama_token};

use super::numeric::{nonnegative_i32_to_usize_opt, usize_to_i32};

#[inline]
pub(super) fn resolve_initial_decode_context_reservation(
    max_output_tokens: i32,
    decode_reserve: i32,
) -> i32 {
    if max_output_tokens <= 0 {
        0
    } else {
        max_output_tokens.min(decode_reserve.max(1))
    }
}

/// Which prefix-reuse paths are enabled for a single request.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(super) struct PrefixReusePlan {
    /// Reuse the common prefix with tokens already live in the slot's KV cache.
    pub live: bool,
    /// Attempt a restore from the snapshot prefix cache.
    pub snapshot: bool,
}

/// Maps the configured [`KvReuseMode`] onto the concrete reuse paths to try.
///
/// `bypass_prefix_cache` overrides the mode and disables both the live and
/// the snapshot path.
pub(super) fn prefix_reuse_plan(mode: KvReuseMode, bypass_prefix_cache: bool) -> PrefixReusePlan {
    let (mode_live, mode_snapshot) = match mode {
        KvReuseMode::Disabled => (false, false),
        KvReuseMode::LiveSlotPrefix => (true, false),
        KvReuseMode::StateSnapshot => (false, true),
        KvReuseMode::LiveSlotAndSnapshot => (true, true),
    };
    let enabled = !bypass_prefix_cache;
    PrefixReusePlan {
        live: enabled && mode_live,
        snapshot: enabled && mode_snapshot,
    }
}

/// Length of the live-slot prefix that can be reused for `prompt_tokens`.
///
/// Returns `0` unless the plan enables live reuse and the candidate is
/// [`CacheCandidate::Live`]. When `allow_partial_kv` is true, the full longest
/// common prefix is reusable. When it is false, the cached tokens are reused
/// only if they form a strict prefix of the prompt, so no KV truncation is
/// needed. The prompt-preparation path passes false for recurrent/hybrid
/// models.
pub(super) fn live_candidate_lcp(
    plan: PrefixReusePlan,
    cache_candidate: CacheCandidate,
    cached_tokens: &[llama_token],
    prompt_tokens: &[llama_token],
    allow_partial_kv: bool,
) -> usize {
    let eligible = plan.live && cache_candidate == CacheCandidate::Live;
    if !eligible {
        return 0;
    }

    let common = compute_lcp_reuse(cached_tokens, prompt_tokens);
    let cached_is_strict_prefix =
        cached_tokens.len() < prompt_tokens.len() && common == cached_tokens.len();
    if allow_partial_kv || cached_is_strict_prefix {
        common
    } else {
        0
    }
}

/// Longest common prefix of `cached_tokens` and `prompt_tokens`, counted only
/// when a cache source was actually used.
///
/// With [`CacheSource::None`] nothing in the mirror is considered reusable,
/// and the result is `0`.
pub(super) fn authorized_lcp(
    source: CacheSource,
    cached_tokens: &[llama_token],
    prompt_tokens: &[llama_token],
) -> usize {
    let reusable = source != CacheSource::None;
    if reusable {
        compute_lcp_reuse(cached_tokens, prompt_tokens)
    } else {
        0
    }
}

/// Slides the KV window so `state.n_past + new_tokens_needed <= n_ctx`,
/// preserving `retained_prefix_tokens` at the head. Returns `false` if the
/// requested window cannot fit inside the active context.
///
/// Behavior, in order:
/// - `seq_id < 0` returns `false`; `new_tokens_needed <= 0` returns `true`
///   without touching anything.
/// - `n_ctx <= 0`, `new_tokens_needed > n_ctx`, or an `i32` overflow of
///   `n_past + new_tokens_needed` returns `false`.
/// - If the request already fits, nothing is modified and `true` is returned.
/// - Otherwise the head to keep is `retained_prefix_tokens` clamped to
///   `[0, n_past]`. Just enough tokens directly after that head are removed,
///   from both the native KV cache and `state.current_kv_tokens`. Positions
///   from the end of the removed range onward are offset by `-n_discard` via
///   `add_sequence_delta`.
/// - If nothing beyond the head can be removed, or removing everything beyond
///   the head is still not enough, the whole sequence is cleared (head
///   included) and `n_past` becomes `0`.
///
/// Returns `false` if a native `clear_sequence` call fails. In that case the
/// cache may already have been partially modified.
pub(super) fn ensure_context_space(
    native_runtime: &mut NativeRuntimeHandle,
    retained_prefix_tokens: i32,
    state: &mut SequenceMirror,
    seq_id: llama_seq_id,
    new_tokens_needed: i32,
) -> bool {
    if seq_id < 0 {
        return false;
    }
    if new_tokens_needed <= 0 {
        return true;
    }

    let n_ctx = native_runtime.n_ctx();
    // A request larger than the whole context can never fit, even after a full clear.
    if n_ctx <= 0 || new_tokens_needed > n_ctx {
        return false;
    }
    let Some(total_needed) = state.n_past.checked_add(new_tokens_needed) else {
        return false;
    };
    if total_needed <= n_ctx {
        return true;
    }

    // Head that should survive the slide: the retained prefix clamped to [0, n_past].
    let n_keep = retained_prefix_tokens.min(state.n_past).max(0);
    // Strictly positive here because `total_needed > n_ctx`.
    let required_discard = total_needed - n_ctx;
    // Only tokens after the retained head may be discarded.
    let max_discard = (state.n_past - n_keep).max(0);
    let n_discard = required_discard.clamp(0, max_discard);

    if n_discard <= 0 {
        // Everything live is retained head, so nothing can be slid out.
        // Fall back to clearing the whole sequence (head included).
        if !native_runtime.clear_sequence(seq_id, 0, -1) {
            return false;
        }
        state.current_kv_tokens.clear();
        state.n_past = 0;
        return true;
    }

    let Some(discard_end) = n_keep.checked_add(n_discard) else {
        return false;
    };

    // Remove KV positions [n_keep, discard_end). Then offset everything from
    // `discard_end` onward by `-n_discard`, presumably closing the gap so
    // positions stay contiguous.
    // NOTE(review): the result of `add_sequence_delta`, if any, is not checked.
    if !native_runtime.clear_sequence(seq_id, n_keep, discard_end) {
        return false;
    }
    native_runtime.add_sequence_delta(seq_id, discard_end, -1, -n_discard);

    // Both values are non-negative here, so these conversions are not expected to fail.
    let Some(n_keep_len) = nonnegative_i32_to_usize_opt(n_keep) else {
        return false;
    };
    let Some(discard_end_len) = nonnegative_i32_to_usize_opt(discard_end) else {
        return false;
    };
    // Apply the same removal to the host-side token mirror.
    // NOTE(review): the `else` arm clears only the mirror, not the native KV.
    // With a consistent mirror (`len == n_past`) it is unreachable, because
    // `n_discard > 0` implies `n_past > n_keep`.
    if state.current_kv_tokens.len() > n_keep_len {
        let erase_end = cmp::min(discard_end_len, state.current_kv_tokens.len());
        state.current_kv_tokens.drain(n_keep_len..erase_end);
    } else {
        state.current_kv_tokens.clear();
    }
    // `n_past` is re-derived from the mirror length, not computed as `n_past - n_discard`.
    let Some(n_past) = usize_to_i32(state.current_kv_tokens.len()) else {
        return false;
    };
    state.n_past = n_past;

    // If the discard was capped at `max_discard`, only the head remains and the
    // request may still not fit. In that case drop the head as well.
    let Some(total_needed) = state.n_past.checked_add(new_tokens_needed) else {
        return false;
    };
    if total_needed <= n_ctx {
        return true;
    }

    if !native_runtime.clear_sequence(seq_id, 0, -1) {
        return false;
    }
    state.current_kv_tokens.clear();
    state.n_past = 0;
    true
}

/// Drives prefix reuse for an admitted request:
///   1. live LCP against the session's existing KV tokens,
///   2. optional restore from the snapshot prefix cache,
///   3. trim the KV to the final match length (honoring recurrent/hybrid
///      model constraints),
///   4. ensure room for the missing prompt tokens + decode reservation.
///
/// Writes the prefill cursor into `out_prefill_cursor`. The cursor is the
/// number of prompt tokens already in the KV cache, i.e. the index of the
/// first prompt token that still has to be decoded. The function returns the
/// cache source and hit count. On success the cursor is always
/// `< prompt_tokens.len()`, so at least one prompt token is left to drive
/// decode.
///
/// Returns `None`, leaving `out_prefill_cursor == 0`, when:
/// - `seq_id` is negative or the prompt is empty,
/// - a checked native KV operation fails,
/// - a length does not fit in `i32`, or
/// - the required context space cannot be made available.
#[allow(clippy::too_many_arguments)]
pub(super) fn prepare_sequence_for_prompt(
    native_runtime: &mut NativeRuntimeHandle,
    retained_prefix_tokens: i32,
    cache_mode: KvReuseMode,
    bypass_prefix_cache: bool,
    decode_token_reserve: i32,
    model_fingerprint: u64,
    kv_cache: &mut KvCacheManager,
    cache_candidate: CacheCandidate,
    requires_kv_clear: bool,
    context_key: &str,
    prompt_tokens: &[llama_token],
    n_tokens_predict: i32,
    state: &mut SequenceMirror,
    seq_id: llama_seq_id,
    out_prefill_cursor: &mut usize,
) -> Option<CachePreparation> {
    *out_prefill_cursor = 0;
    if seq_id < 0 || prompt_tokens.is_empty() {
        return None;
    }

    // Recurrent/hybrid models get all-or-nothing reuse here: whenever the KV
    // would need a partial trim, it is cleared instead.
    let allow_partial_kv = !(native_runtime.is_recurrent() || native_runtime.is_hybrid());
    let reuse_plan = prefix_reuse_plan(cache_mode, bypass_prefix_cache);
    let live_match_len = live_candidate_lcp(
        reuse_plan,
        cache_candidate,
        &state.current_kv_tokens,
        prompt_tokens,
        allow_partial_kv,
    );
    let mut match_len = live_match_len;
    let mut cache_source = if live_match_len > 0 {
        CacheSource::Live
    } else {
        CacheSource::None
    };

    // Snapshot restore is manager-owned; prefill receives only the prefix mirror.
    if reuse_plan.snapshot {
        if let Some(snapshot) = kv_cache.restore_best_snapshot_prefix(
            native_runtime,
            seq_id,
            model_fingerprint,
            context_key,
            prompt_tokens,
            live_match_len,
        ) {
            state.current_kv_tokens = snapshot.prefix_tokens;
            state.n_past = usize_to_i32(snapshot.token_count)?;
            match_len = snapshot.token_count.min(prompt_tokens.len());
            cache_source = CacheSource::Snapshot;
        }
    }

    if cache_source == CacheSource::None {
        // Nothing reusable: start prefill from an empty sequence.
        if requires_kv_clear || !state.current_kv_tokens.is_empty() {
            clear_sequence_state(native_runtime, seq_id, state);
        }
        match_len = 0;
    }

    // Re-run LCP — it can grow after a cache restore (but never after `ensure_context_space`).
    match_len = match_len.max(authorized_lcp(
        cache_source,
        &state.current_kv_tokens,
        prompt_tokens,
    ));
    let missing_prompt_tokens = prompt_tokens.len().checked_sub(match_len)?;
    let tokens_to_add = usize_to_i32(missing_prompt_tokens)?;
    // Checked addition: an overflow would panic in debug builds. Without
    // overflow checks it would wrap to a negative request, which
    // `ensure_context_space` accepts as "nothing needed", skipping the
    // capacity check.
    let total_needed = tokens_to_add.checked_add(resolve_initial_decode_context_reservation(
        n_tokens_predict,
        decode_token_reserve,
    ))?;

    if !ensure_context_space(
        native_runtime,
        retained_prefix_tokens,
        state,
        seq_id,
        total_needed,
    ) {
        return None;
    }

    // Sliding or clearing the window may have removed matched tokens, so the
    // match can only shrink here.
    // NOTE(review): the space check above used the pre-trim match. If the
    // match shrinks, the total becomes `prompt_tokens.len()` + reservation,
    // which is not re-checked against `n_ctx` — confirm the prompt length is
    // bounded upstream.
    match_len = match_len.min(authorized_lcp(
        cache_source,
        &state.current_kv_tokens,
        prompt_tokens,
    ));
    if match_len < state.current_kv_tokens.len() {
        // Drop every cached token past the shared prefix.
        if allow_partial_kv {
            let match_len_i32 = usize_to_i32(match_len)?;
            if !native_runtime.clear_sequence(seq_id, match_len_i32, -1) {
                return None;
            }
            state.current_kv_tokens.truncate(match_len);
            state.n_past = match_len_i32;
        } else {
            clear_sequence_state(native_runtime, seq_id, state);
            match_len = 0;
            cache_source = CacheSource::None;
        }
    }

    // Full-prompt cache hit needs a token to drive decode: trim 1 from KV or
    // invalidate.
    if match_len == prompt_tokens.len() && match_len > 0 {
        if allow_partial_kv {
            let match_len_i32 = usize_to_i32(match_len)?;
            let last_token_position = match_len_i32.checked_sub(1)?;
            if !native_runtime.clear_sequence(seq_id, last_token_position, -1) {
                return None;
            }
            state.current_kv_tokens.truncate(match_len - 1);
            state.n_past = last_token_position;
            match_len -= 1;
        } else {
            clear_sequence_state(native_runtime, seq_id, state);
            match_len = 0;
        }
    }

    let cache_hits = usize_to_i32(match_len)?;
    *out_prefill_cursor = match_len;
    if cache_hits == 0 {
        cache_source = CacheSource::None;
    }
    Some(CachePreparation {
        source: cache_source,
        cache_hits,
    })
}

/// Clears `seq_id` in the native KV cache from position `0` to the open end
/// (`-1`), then resets the host-side mirror to an empty sequence.
///
/// Best effort: the result of `clear_sequence` is intentionally not checked,
/// so the mirror is reset even if the native clear reports failure.
fn clear_sequence_state(
    native_runtime: &mut NativeRuntimeHandle,
    seq_id: llama_seq_id,
    state: &mut SequenceMirror,
) {
    let _ = native_runtime.clear_sequence(seq_id, 0, -1);
    state.n_past = 0;
    state.current_kv_tokens.clear();
}

/// Per-step variant of `ensure_context_space`: makes room for one more decode
/// token in `slot`'s sequence.
///
/// Behavior:
/// - Returns `false` when the slot has no active request.
/// - Returns `true` without touching the cache while no tokens have been
///   generated yet.
/// - For multimodal turns, the next token must fit strictly within `n_ctx`.
///   If `n_past + 1` exceeds `n_ctx` (or overflows), the step fails instead
///   of sliding, so the multimodal prefix is never evicted.
pub(super) fn ensure_decode_step_context_space(
    native_runtime: &mut NativeRuntimeHandle,
    retained_prefix_tokens: i32,
    slot: &mut SlotState,
) -> bool {
    let is_multimodal_turn = match slot.request() {
        Some(request) => request.is_multimodal_turn,
        None => return false,
    };
    if slot.generated_tokens.is_empty() {
        return true;
    }
    if is_multimodal_turn {
        let fits_without_sliding = match slot.mirror.n_past.checked_add(1) {
            Some(needed) => needed <= native_runtime.n_ctx(),
            None => false,
        };
        if !fits_without_sliding {
            return false;
        }
    }
    ensure_context_space(
        native_runtime,
        retained_prefix_tokens,
        &mut slot.mirror,
        slot.seq_id,
        1,
    )
}

/// Length of the longest common prefix of `cached_tokens` and
/// `incoming_tokens`. The result never exceeds the shorter slice's length.
fn compute_lcp_reuse(cached_tokens: &[llama_token], incoming_tokens: &[llama_token]) -> usize {
    let shared_len = cached_tokens.len().min(incoming_tokens.len());
    cached_tokens
        .iter()
        .zip(incoming_tokens)
        .position(|(cached, incoming)| cached != incoming)
        .unwrap_or(shared_len)
}