ambi 0.1.7

A flexible, customizable, multi-backend AI agent framework written entirely in Rust.
// src/llm/providers/llama_cpp/engine/inference.rs

use super::session::InferenceSession;
use crate::error::{AmbiError, Result};
use crate::types::config::LlamaEngineConfig;
use llama_cpp_2::context::LlamaContext;
use llama_cpp_2::llama_batch::LlamaBatch;
use llama_cpp_2::model::{AddBos, LlamaModel};
use llama_cpp_2::sampling::LlamaSampler;
use llama_cpp_2::token::LlamaToken;
use log::{debug, info, warn};

impl InferenceSession {
    /// Execute a full completion loop for the given prompt.
    ///
    /// # Parameters
    /// - `prompt`: the formatted prompt to be tokenized and processed.
    /// - `model`: the loaded model, used for tokenization and detokenization.
    /// - `context`: the llama.cpp context that owns the KV cache.
    /// - `batch`: a reusable batch buffer for prompt evaluation and decoding.
    /// - `session`: mutable session state (token history, position, UTF-8 buffer).
    /// - `cfg`: static engine configuration (temperature, penalties, etc.).
    /// - `callback`: invoked for every successfully decoded UTF-8 string.
    ///   Returning `false` signals early termination (e.g. receiver dropped).
    ///
    /// # Returns
    /// `Ok(())` on normal completion (end-of-generation or callback
    /// termination), otherwise an `AmbiError::EngineError` describing the
    /// failure point.
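    ///
    /// # Example
    ///
    /// A minimal calling sketch. The constructors shown here
    /// (`InferenceSession::new`, `LlamaEngineConfig::default`) are assumed for
    /// illustration and may not match this crate's actual API:
    ///
    /// ```ignore
    /// let mut session = InferenceSession::new();
    /// let cfg = LlamaEngineConfig::default();
    /// InferenceSession::run_inference(
    ///     "<|user|>\nHello!<|assistant|>\n",
    ///     &model,
    ///     &mut context,
    ///     &mut batch,
    ///     &mut session,
    ///     &cfg,
    ///     |piece| {
    ///         print!("{piece}");
    ///         true // return false to stop generation early
    ///     },
    /// )?;
    /// ```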
    pub(crate) fn run_inference<F>(
        prompt: &str,
        model: &LlamaModel,
        context: &mut LlamaContext,
        batch: &mut LlamaBatch,
        session: &mut InferenceSession,
        cfg: &LlamaEngineConfig,
        mut callback: F,
    ) -> Result<()>
    where
        F: FnMut(String) -> bool,
    {
        // ----- 1. Tokenize prompt --------------------------------------------------
        debug!("\n{}\n========================================", prompt);
        let tokens_list = model
            .str_to_token(prompt, AddBos::Always)
            .map_err(|e| AmbiError::EngineError(format!("Tokenize failed: {}", e)))?;
        let current_tokens: Vec<LlamaToken> = tokens_list.to_vec();

        // ----- 2. Validate prompt size vs context window --------------------------
        if current_tokens.len() >= cfg.n_ctx as usize {
            return Err(AmbiError::EngineError(format!(
                "Prompt size ({} tokens) exceeds or equals n_ctx limit ({})",
                current_tokens.len(),
                cfg.n_ctx
            )));
        }

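        // Cap generation length so that prompt + output never exceeds the
        // context window, taking whichever bound (the user limit or the
        // remaining space) is tighter.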
        let dynamic_max_tokens = std::cmp::min(
            cfg.max_tokens as usize,
            (cfg.n_ctx as usize).saturating_sub(current_tokens.len()),
        );

        if dynamic_max_tokens < 32 {
            return Err(AmbiError::EngineError(format!(
                "Insufficient token space left for generation (only {} tokens). \
                 Increase n_ctx or reduce prompt length.",
                dynamic_max_tokens
            )));
        }

        // ----- 3. KV‑cache eviction optimisation ----------------------------------
        // Compare the new prompt with the previously cached history to avoid
        // re‑evaluating the common prefix.
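        // Example: history = [A, B, C, D] and prompt = [A, B, X] yields
        // match_len = 2; tokens C and D must be evicted from the cache, and
        // only X (plus anything after it) needs evaluation.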
        let mut match_len = session
            .history_tokens
            .iter()
            .zip(current_tokens.iter())
            .take_while(|(t1, t2)| t1 == t2)
            .count();

        // Always leave at least the final prompt token to be re-evaluated, so
        // that the decode below produces fresh logits for sampling even when
        // the entire prompt is already cached.
        if match_len == current_tokens.len() {
            match_len = current_tokens.len().saturating_sub(1);
        }

        // If the common prefix is shorter than our full history, we must
        // remove the divergent suffix from the KV cache and shift remaining
        // entries.
        if match_len < session.history_tokens.len() {
            let evicted_len = session.history_tokens.len() - match_len;
            info!(
                "Evicting {} divergent tokens from the KV cache; the cached prefix is reused.",
                evicted_len
            );

            let p0 = match_len as u32; // start of the range to remove
            let p1 = (match_len + evicted_len) as u32; // exclusive end

            if let Err(e) = context.clear_kv_cache_seq(Some(0), Some(p0), Some(p1)) {
                warn!(
                    "Failed to cleanly remove KV‑cache sequence: {}. Falling back to full reset.",
                    e
                );
                context.clear_kv_cache();
                match_len = 0;
            } else if let Err(e) =
                context.kv_cache_seq_add(0, Some(p1), None, -(evicted_len as i32))
            {
                warn!(
                    "Failed to shift KV‑cache sequence: {}. Falling back to full reset.",
                    e
                );
                context.clear_kv_cache();
                match_len = 0;
            } else {
                // Shift successful; history is now the common prefix.
                session.history_tokens.truncate(match_len);
                // match_len stays as the new effective length.
            }
        }

        // ----- 4. Evaluate new tokens (the suffix that is not cached) -------------
        session.pos = match_len as i32;
        let new_tokens = &current_tokens[match_len..];

        let chunk_size = cfg.n_tokens;
        let total_new = new_tokens.len();
        let mut processed = 0;

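        // Feed the suffix in `chunk_size`-sized batches. Logits are requested
        // only for the final token of the entire suffix (`is_last`), since
        // that is the sole position we sample from afterwards.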
        for chunk in new_tokens.chunks(chunk_size) {
            batch.clear();
            for &t in chunk.iter() {
                processed += 1;
                let is_last = processed == total_new;
                batch
                    .add(t, session.pos, &[0], is_last)
                    .map_err(|e| AmbiError::EngineError(format!("Batch add failed: {}", e)))?;
                session.pos += 1;
            }
            if !chunk.is_empty() {
                context
                    .decode(batch)
                    .map_err(|e| AmbiError::EngineError(format!("Decoding failed: {}", e)))?;
            }
        }

        // Now our history fully reflects the current prompt.
        session.history_tokens = current_tokens;

        // ----- 5. Configure sampler chain -----------------------------------------
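        // The samplers run in order: repetition penalties adjust the logits,
        // top-p truncates the low-probability tail, temperature rescales, and
        // dist() performs the final seeded draw.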
        let mut sampler = LlamaSampler::chain_simple([
            LlamaSampler::penalties(
                cfg.penalty_last_n,
                cfg.penalty_repeat,
                cfg.penalty_freq,
                cfg.penalty_present,
            ),
            LlamaSampler::top_p(cfg.top_p, cfg.min_keep),
            LlamaSampler::temp(cfg.temp),
            LlamaSampler::dist(cfg.seed),
        ]);

        // ----- 6. Generation loop -------------------------------------------------
        let mut decoded_count = 0;

        loop {
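            // `batch.n_tokens() - 1` is the index of the last token handed to
            // decode(), the only position whose logits were requested.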
            let next_token = sampler.sample(context, batch.n_tokens() - 1);
            sampler.accept(next_token);

            // Check for end-of-generation or length limit.
            if model.is_eog_token(next_token) || decoded_count >= dynamic_max_tokens {
                break;
            }

            // Decode token bytes and assemble valid UTF-8 strings.
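            // A single Unicode scalar (e.g. a 4-byte emoji) can span several
            // token pieces, so bytes are buffered until they form valid UTF-8
            // rather than emitting replacement characters mid-stream.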
            if let Ok(bytes) = model.token_to_piece_bytes(next_token, cfg.buffer_size, true, None) {
                session.utf8_buffer.extend_from_slice(&bytes);

                match std::str::from_utf8(&session.utf8_buffer) {
                    Ok(valid_str) => {
                        if !callback(valid_str.to_string()) {
                            // Callback requested early stop (e.g. stream dropped).
                            break;
                        }
                        session.utf8_buffer.clear();
                    }
                    Err(e) => {
                        let valid_up_to = e.valid_up_to();
                        if valid_up_to > 0 {
                            // Emit the valid prefix even if the buffer is incomplete.
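                            // SAFETY: `valid_up_to` marks the longest prefix
                            // that `from_utf8` verified as valid UTF-8.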
                            let valid_str = unsafe {
                                std::str::from_utf8_unchecked(&session.utf8_buffer[..valid_up_to])
                            };
                            if !callback(valid_str.to_string()) {
                                break;
                            }
                            session.utf8_buffer.drain(..valid_up_to);
                        }
                        // Otherwise keep the incomplete bytes for the next token.
                    }
                }
            } // If token_to_piece_bytes fails we simply skip this token (rare edge case).

            // Prepare the batch for the next iteration.
            batch.clear();
            batch
                .add(next_token, session.pos, &[0], true)
                .map_err(|e| AmbiError::EngineError(format!("Batch add failed: {}", e)))?;
            context
                .decode(batch)
                .map_err(|e| AmbiError::EngineError(format!("Decoding failed: {}", e)))?;

            // Record the token for future KV-cache prefix matching only now
            // that it has actually been decoded into the cache; pushing it
            // before an early callback break would leave the history and the
            // KV cache out of sync.
            session.history_tokens.push(next_token);

            session.pos += 1;
            decoded_count += 1;
        }

        debug!("Generation finished after {} new tokens.", decoded_count);
        Ok(())
    }
}