sipp-rs 0.1.0

Unified Rust library for extensible Sipp inference
use std::collections::HashMap;

use crate::native_bridge::NativeRuntimeHandle;
use crate::runtime::config::NativeRuntimeConfig;
use crate::runtime::llama_seq_id;
use crate::runtime::scheduler::{SamplerCacheKey, SamplerHandle, SlotState};

use super::super::sampler::{attach_backend_sampler, create_sampler, ResidentBackendSampler};

pub(super) fn ensure_slot_sampler(
    slot: &mut SlotState,
    native_runtime: &mut NativeRuntimeHandle,
    config: &NativeRuntimeConfig,
    sampler_pool: &mut HashMap<SamplerCacheKey, Vec<SamplerHandle>>,
    resident_backend_samplers: &mut HashMap<llama_seq_id, ResidentBackendSampler>,
) -> bool {
    let Some(request) = slot.request() else {
        return false;
    };

    let sampling_json = match config.try_sampling_json_with_override(request.sampling.as_ref()) {
        Ok(sampling_json) => sampling_json,
        Err(error) => {
            slot.fail(format!(
                "Failed to serialize sampler configuration: {error}"
            ));
            return false;
        }
    };
    let key = SamplerCacheKey {
        sampling_json,
        grammar: request.grammar.clone(),
        json_schema: request.json_schema.clone(),
    };

    if let Some(resident) = resident_backend_samplers.remove(&slot.seq_id) {
        if resident.key == key {
            slot.set_sampler(resident.sampler);
            slot.sampler_key = Some(key);
            slot.backend_sampler_attached = true;
            return true;
        }

        if slot.seq_id >= 0 {
            native_runtime.detach_sampler(slot.seq_id);
        }
    }

    if let Some(sampler) = sampler_pool.get_mut(&key).and_then(|vec| vec.pop()) {
        slot.set_sampler(sampler);
        slot.sampler_key = Some(key);
        attach_backend_sampler(native_runtime, slot);
        return true;
    }

    let sampling = slot.request().and_then(|request| request.sampling.as_ref());
    match create_sampler(
        native_runtime,
        config,
        sampling,
        Some(&key.grammar),
        Some(&key.json_schema),
    ) {
        Ok(sampler) => {
            slot.set_sampler(sampler);
            slot.sampler_key = Some(key);
            attach_backend_sampler(native_runtime, slot);
            true
        }
        Err(_) => {
            let message = if key.grammar.is_empty() {
                "Failed to create per-slot sampler."
            } else {
                "Failed to create per-slot grammar sampler."
            };
            slot.fail(message);
            false
        }
    }
}