sipp-rs 0.1.0

Unified Rust library for extensible Sipp inference
use crate::native_bridge::NativeRuntimeHandle;
use crate::runtime::config::NativeRuntimeConfig;
use crate::runtime::request::GenerateRequestLifecycle;
use crate::runtime::request::RequestQueue;
use crate::runtime::scheduler::{PrefillKind, SlotPhase, SlotState, TerminalAction};
use crate::runtime::session::KvCacheManager;
use crate::runtime::REQUEST_CANCELLED_MESSAGE;

use super::environment::{live_retained_prefix_tokens, resolve_batch_token_budget};
use super::multimodal::run_multimodal_prefill;
use super::prefill::{ensure_decode_step_context_space, prepare_sequence_for_prompt};
use super::InferenceRuntime;

mod recovery;
mod sampler_attach;

use recovery::normalize_runnable_slot_state;
use sampler_attach::ensure_slot_sampler;

/////////////////////////////////////////////////////////////////////////////////
/// TESTS
/////////////////////////////////////////////////////////////////////////////////

#[cfg(test)]
#[path = "../../../tests/runtime/inference_runtime/slot_tests.rs"]
mod slot_tests;

/////////////////////////////////////////////////////////////////////////////////
/// SRC
/////////////////////////////////////////////////////////////////////////////////

impl InferenceRuntime {
    /// Prepares every occupied slot for the upcoming scheduler tick.
    ///
    /// For each slot bound to a sequence (`seq_id >= 0`) that carries a request, in order:
    /// 1. a pending cancellation is honoured via `slot.cancel`, and nothing else is done;
    /// 2. runnable state is normalized with `normalize_runnable_slot_state`;
    /// 3. slots whose terminal action is not `SampleTokens` release any resident backend
    ///    sampler still registered for their `seq_id`. Sampling slots without a sampler get
    ///    one via `ensure_slot_sampler`, and are skipped for this tick if that fails;
    /// 4. slots in `Prefill` with `prefill_cursor == 0` run their initial prefill, and are
    ///    skipped for the rest of the tick when `run_initial_prefill` returns `true`;
    /// 5. slots in `Decode` must have context headroom for the next step, or they are failed;
    /// 6. the request lifecycle is set to `Running`. This step is skipped for slots that ended
    ///    up `Failed`, so a terminal lifecycle recorded above is not overwritten.
    pub(super) fn normalize_slots_for_tick(&mut self) {
        // Iterating the slot vector borrows only `self.slot_scheduler.slots`, so the other
        // runtime fields can still be lent out mutably inside the loop body.
        for slot in self.slot_scheduler.slots.iter_mut() {
            // Unbound slots and slots without a request have nothing to run.
            if slot.seq_id < 0 {
                continue;
            }
            let Some(request) = slot.request() else {
                continue;
            };
            if request.cancel_requested {
                slot.cancel(REQUEST_CANCELLED_MESSAGE);
                continue;
            }

            normalize_runnable_slot_state(
                slot,
                &mut self.native_runtime,
                live_retained_prefix_tokens(&self.config),
            );

            if slot.plan.terminal != TerminalAction::SampleTokens {
                // Embedding-only slots have no sampler; any resident sampler for
                // this physical sequence belongs to a previous text request.
                // `seq_id` is re-checked because the normalization above received the
                // slot mutably and may have changed it.
                let seq_id = slot.seq_id;
                if seq_id >= 0 && self.resident_backend_samplers.remove(&seq_id).is_some() {
                    self.native_runtime.detach_sampler(seq_id);
                }
            } else if slot.sampler.is_none()
                && !ensure_slot_sampler(
                    slot,
                    &mut self.native_runtime,
                    &self.config,
                    &mut self.sampler_pool,
                    &mut self.resident_backend_samplers,
                )
            {
                continue;
            }

            // `true` means the slot was fully handled (multimodal turn) or failed.
            if slot.phase == SlotPhase::Prefill
                && slot.prefill_cursor == 0
                && run_initial_prefill(
                    slot,
                    &mut self.native_runtime,
                    &self.config,
                    self.model_fingerprint,
                    &mut self.kv_cache,
                    &mut self.total_cache_hits,
                    &mut self.request_queue,
                    &mut self.scratch_token_piece,
                )
            {
                continue;
            }

            if slot.phase == SlotPhase::Decode
                && !ensure_decode_step_context_space(
                    &mut self.native_runtime,
                    live_retained_prefix_tokens(&self.config),
                    slot,
                )
            {
                slot.fail("Failed to extend decode context headroom.");
                continue;
            }

            // A slot may have failed earlier in this pass (e.g. the initial text prefill
            // marks both the slot and its request as failed). Never clobber that terminal
            // lifecycle with `Running`.
            if slot.phase == SlotPhase::Failed {
                continue;
            }
            if let Some(request) = slot.request_mut() {
                request.lifecycle = GenerateRequestLifecycle::Running;
            }
        }
    }
}

/// Runs the initial prefill step for a slot that has not consumed any prompt tokens yet.
///
/// * Multimodal turns are delegated to `run_multimodal_prefill`. If it reports failure:
///   - the slot is marked `Failed`, keeping any error message it already recorded;
///   - the request's lifecycle becomes `Failed`;
///   - its multimodal payload is dropped.
/// * Text turns go through `prepare_sequence_for_prompt`. Prefix-cache reuse is bypassed for
///   encoder prefills and embedding reads. The helper fills in `prefill_cursor`.
///   - On success, the request's cache statistics are recorded and `total_cache_hits` is
///     bumped.
///   - The sampler may be fed the prompt tokens (see below).
///   - The slot moves to `Decode` if the cursor already covers the whole prompt. Otherwise it
///     stays in `Prefill`.
///   - On failure, the slot and its request are marked `Failed`.
///
/// Returns `true` when the caller must skip the rest of this tick's work for the slot: a
/// multimodal turn was handled, or the slot failed. Returns `false` when the caller should
/// continue, including when the slot carries no request.
// The runtime's fields are passed individually (rather than `&mut InferenceRuntime`) so the
// caller can hold a `&mut` borrow of one slot while lending out the remaining fields; that is
// why the argument count exceeds Clippy's default limit.
#[allow(clippy::too_many_arguments)]
fn run_initial_prefill(
    slot: &mut SlotState,
    native_runtime: &mut NativeRuntimeHandle,
    config: &NativeRuntimeConfig,
    model_fingerprint: u64,
    kv_cache: &mut KvCacheManager,
    total_cache_hits: &mut usize,
    request_queue: &mut RequestQueue,
    scratch_token_piece: &mut Vec<u8>,
) -> bool {
    if slot
        .request()
        .is_some_and(|request| request.is_multimodal_turn)
    {
        let ok = run_multimodal_prefill(
            native_runtime,
            resolve_batch_token_budget(native_runtime, config),
            request_queue,
            slot,
            scratch_token_piece,
        );
        if !ok {
            // Preserve a more specific message if the multimodal path already set one.
            if slot.terminal_error_message.is_empty() {
                slot.terminal_error_message = "Failed to evaluate multimodal prompt.".to_string();
            }
            slot.phase = SlotPhase::Failed;
            if let Some(request) = slot.request_mut() {
                request.lifecycle = GenerateRequestLifecycle::Failed;
                request.multimodal = None;
            }
        }
        // The multimodal turn is fully handled here, whether it succeeded or failed.
        return true;
    }

    // Encoder-decoder prompts are rewritten to a single decoder-start token by
    // the admission pass. Allowing the prefix cache to LCP-match that token
    // across unrelated source prompts would let an old turn's decoder KV
    // poison the new turn — disable cache reuse and start from a fresh KV
    // state for this slot. Same rule for embedding-context requests, whose
    // outputs are read directly from the encoder pass, not from cached KV.
    let bypass_prefix_cache = slot.plan.prefill == PrefillKind::Encode
        || slot.plan.terminal == TerminalAction::ReadEmbedding;
    // Copy these out before `slot.request` is borrowed mutably below.
    let cache_candidate = slot.cache_candidate;
    let requires_kv_clear = slot.requires_kv_clear;
    let Some(ref mut request) = slot.request else {
        return false;
    };
    let mut prefill_cursor = 0;
    if let Some(cache_preparation) = prepare_sequence_for_prompt(
        native_runtime,
        live_retained_prefix_tokens(config),
        config.cache.mode,
        bypass_prefix_cache,
        config.scheduler.policy.decode_token_reserve,
        model_fingerprint,
        kv_cache,
        cache_candidate,
        requires_kv_clear,
        &request.context_key,
        &request.prompt_tokens,
        request.max_output_tokens,
        &mut slot.mirror,
        slot.seq_id,
        &mut prefill_cursor,
    ) {
        request.cache_hits = cache_preparation.cache_hits;
        if cache_preparation.cache_hits > 0 {
            *total_cache_hits =
                total_cache_hits.saturating_add(cache_preparation.cache_hits as usize);
        }
        request.cache_source = cache_preparation.source;
        slot.requires_kv_clear = false;
        // Prompt seeding happens at most once per slot and is skipped for grammar- or
        // JSON-schema-constrained requests.
        if !slot.sampler_prompt_seeded
            && request.grammar.is_empty()
            && request.json_schema.is_empty()
        {
            if let Some(sampler) = slot.sampler.as_mut() {
                let seed_start = config.prompt_sampler_seed_start(
                    request.sampling.as_ref(),
                    request.prompt_tokens.len(),
                );
                // `skip` (instead of slicing) cannot panic if `seed_start` ever exceeds
                // the prompt length; in range it yields exactly the same tokens.
                for &token in request.prompt_tokens.iter().skip(seed_start) {
                    if !sampler.accept(token, false) {
                        break;
                    }
                }
                slot.sampler_prompt_seeded = true;
            }
        }
        slot.prefill_cursor = prefill_cursor;
        slot.phase = if slot.prefill_cursor >= request.prompt_tokens.len() {
            SlotPhase::Decode
        } else {
            SlotPhase::Prefill
        };
    } else {
        slot.terminal_error_message = "Failed to prepare sequence for prompt reuse.".to_string();
        slot.phase = SlotPhase::Failed;
        request.lifecycle = GenerateRequestLifecycle::Failed;
        // The slot has failed: tell the caller to stop processing it, otherwise it would
        // go on to overwrite the `Failed` lifecycle with `Running`.
        return true;
    }
    false
}