ambi 0.2.4

A flexible, customizable, multi-backend AI agent framework written entirely in Rust.
Documentation
// src/llm/providers/llama_cpp/inference.rs

use super::session::InferenceSession;
use crate::error::{AmbiError, Result};
use crate::llm::providers::llama_cpp::vision::VisionContext;
use crate::types::config::LlamaEngineConfig;
use llama_cpp_2::context::LlamaContext;
use llama_cpp_2::llama_batch::LlamaBatch;
use llama_cpp_2::model::{AddBos, LlamaModel};
use llama_cpp_2::sampling::LlamaSampler;
use llama_cpp_2::token::LlamaToken;
use log::{debug, info, warn};

impl InferenceSession {
    /// Execute a full completion loop for the given prompt.
    ///
    /// All inputs are validated before any KV-cache state is modified. The
    /// longest prefix of the prompt that is already cached is reused, and only
    /// the remaining tokens are evaluated before sampling starts.
    ///
    /// # Parameters
    /// - `prompt`: the formatted prompt to be tokenized and processed.
    /// - `images`: image inputs attached to this request. This function only
    ///   validates them against `vision_ctx`; it does not read their contents.
    /// - `vision_ctx`: the configured vision backend, if any.
    /// - `model`: used for tokenization, detokenization and end-of-generation
    ///   detection.
    /// - `context`: the llama.cpp context whose KV cache (sequence 0) is reused.
    /// - `batch`: reusable batch buffer for decoding.
    /// - `session`: bookkeeping of which tokens are cached. It is rolled back on
    ///   failure or early termination.
    /// - `cfg`: static engine configuration (temperature, penalties, etc.).
    /// - `callback`: invoked for every successfully decoded UTF-8 string.
    ///   Returning `false` signals early termination (e.g. receiver dropped).
    ///
    /// # Returns
    /// `Ok(())` on normal completion (end-of-generation, token limit, or
    /// callback termination). Otherwise returns an `AmbiError::EngineError`
    /// describing the failure point.
    pub(crate) fn run_inference<F>(
        prompt: &str,
        images: &[String],
        vision_ctx: Option<&VisionContext>,
        model: &LlamaModel,
        context: &mut LlamaContext,
        batch: &mut LlamaBatch,
        session: &mut InferenceSession,
        cfg: &LlamaEngineConfig,
        mut callback: F,
    ) -> Result<()>
    where
        F: FnMut(String) -> bool,
    {
        debug!("\n{}\n========================================", prompt);

        let snapshot = session.snapshot();

        // Run every check that can reject the request before touching the KV
        // cache. That way a rejected request never evicts reusable entries.
        let current_tokens = Self::tokenize_prompt(model, prompt)?;
        Self::validate_prompt_length(&current_tokens, cfg)?;
        let dynamic_max_tokens = Self::calculate_max_tokens(&current_tokens, cfg)?;
        Self::process_images(images, vision_ctx)?;

        let match_len = Self::handle_kv_cache(session, context, &current_tokens);

        Self::eval_new_tokens(
            session,
            context,
            batch,
            cfg,
            &current_tokens,
            match_len,
            &snapshot,
        )?;

        let mut sampler = Self::create_sampler(cfg);

        Self::generation_loop(
            session,
            model,
            context,
            batch,
            cfg,
            &mut sampler,
            dynamic_max_tokens,
            &snapshot,
            &mut callback,
        )
    }

    /// Tokenize `prompt`, always prepending a BOS token.
    fn tokenize_prompt(model: &LlamaModel, prompt: &str) -> Result<Vec<LlamaToken>> {
        model
            .str_to_token(prompt, AddBos::Always)
            .map_err(|e| AmbiError::EngineError(format!("Tokenize failed: {}", e)))
    }

    /// Reject prompts that are empty or that would fill the whole context window.
    fn validate_prompt_length(tokens: &[LlamaToken], cfg: &LlamaEngineConfig) -> Result<()> {
        if tokens.is_empty() {
            // Sampling needs logits from at least one evaluated prompt token.
            return Err(AmbiError::EngineError(
                "Prompt produced no tokens; nothing to evaluate".into(),
            ));
        }
        if tokens.len() >= cfg.n_ctx as usize {
            return Err(AmbiError::EngineError(format!(
                "Prompt size ({} tokens) exceeds or equals n_ctx limit ({})",
                tokens.len(),
                cfg.n_ctx
            )));
        }
        Ok(())
    }

    /// Compute how many tokens may be generated.
    ///
    /// The result is the smaller of `cfg.max_tokens` and the context space left
    /// after the prompt. An error is returned when fewer than
    /// `MIN_GENERATION_TOKENS` tokens would remain.
    fn calculate_max_tokens(tokens: &[LlamaToken], cfg: &LlamaEngineConfig) -> Result<usize> {
        const MIN_GENERATION_TOKENS: usize = 32;

        let dynamic_max = std::cmp::min(
            cfg.max_tokens as usize,
            (cfg.n_ctx as usize).saturating_sub(tokens.len()),
        );
        if dynamic_max < MIN_GENERATION_TOKENS {
            return Err(AmbiError::EngineError(format!(
                "Insufficient token space left for generation (only {} tokens). \
                 Increase n_ctx or reduce prompt length.",
                dynamic_max
            )));
        }
        Ok(dynamic_max)
    }

    /// Reuse the cached prefix shared by `session.history_tokens` and
    /// `current_tokens`, and evict everything cached after it.
    ///
    /// Returns the number of leading prompt tokens that are already in the KV
    /// cache and need no re-evaluation. On return, `session.history_tokens` and
    /// `session.pos` describe exactly what remains cached. If a partial eviction
    /// is not possible, the whole cache is cleared and 0 is returned.
    fn handle_kv_cache(
        session: &mut InferenceSession,
        context: &mut LlamaContext,
        current_tokens: &[LlamaToken],
    ) -> usize {
        let mut match_len = session
            .history_tokens
            .iter()
            .zip(current_tokens)
            .take_while(|(cached, new)| cached == new)
            .count();

        // Sampling needs fresh logits for the final prompt token, and those only
        // exist if that token is decoded in this run. If the whole prompt is
        // already cached, step back one token so it is evicted and re-evaluated.
        if match_len > 0 && match_len == current_tokens.len() {
            match_len -= 1;
        }

        let cached_len = session.history_tokens.len();
        if match_len < cached_len {
            let evicted_len = cached_len - match_len;
            info!(
                "Evicting {} tokens, applying KV‑cache shift to save evaluation cost.",
                evicted_len
            );

            let p0 = match_len as u32;
            let p1 = cached_len as u32;

            let evicted_cleanly = match context.clear_kv_cache_seq(Some(0), Some(p0), Some(p1)) {
                Ok(true) => {
                    match context.kv_cache_seq_add(0, Some(p1), None, -(evicted_len as i32)) {
                        Ok(_) => true,
                        Err(e) => {
                            warn!(
                                "Failed to shift KV‑cache sequence: {}. Falling back to full reset.",
                                e
                            );
                            false
                        }
                    }
                }
                // `Ok(false)` means the backend refused a partial removal (e.g.
                // recurrent models), so the range is still cached.
                Ok(false) => {
                    warn!(
                        "KV‑cache backend cannot remove a partial sequence. \
                         Falling back to full reset."
                    );
                    false
                }
                Err(e) => {
                    warn!(
                        "Failed to cleanly remove KV‑cache sequence: {}. Falling back to full reset.",
                        e
                    );
                    false
                }
            };

            if evicted_cleanly {
                session.history_tokens.truncate(match_len);
            } else {
                // Nothing survives a full reset. Keep the bookkeeping in sync
                // so new tokens are decoded from position 0.
                context.clear_kv_cache();
                session.history_tokens.clear();
                match_len = 0;
            }
        }

        session.pos = match_len as i32;
        match_len
    }

    /// Check that the request's images can be handled by the configured vision
    /// backend.
    ///
    /// With `VisionContext::Integrated` the images are accepted, but this
    /// function does not read their contents.
    fn process_images(images: &[String], vision_ctx: Option<&VisionContext>) -> Result<()> {
        if images.is_empty() {
            return Ok(());
        }
        match vision_ctx {
            None => Err(AmbiError::EngineError(
                "Multimodal input received, but no vision context is configured. \
             Set `mmproj_path` or `integrated_vision` in LlamaEngineConfig."
                    .into(),
            )),
            Some(VisionContext::ExternalProjector { .. }) => Err(AmbiError::EngineError(
                "External projector (mmproj) multimodal support is not yet implemented. \
             It will be available in Ambi 0.3.0."
                    .into(),
            )),
            Some(VisionContext::Integrated) => Ok(()),
        }
    }

    /// Decode `current_tokens[match_len..]` in chunks of `cfg.n_tokens`.
    ///
    /// Logits are requested only for the final prompt token, which is all
    /// sampling needs. On success, `session.history_tokens` becomes the full
    /// prompt. On failure the session is rolled back.
    fn eval_new_tokens(
        session: &mut InferenceSession,
        context: &mut LlamaContext,
        batch: &mut LlamaBatch,
        cfg: &LlamaEngineConfig,
        current_tokens: &[LlamaToken],
        match_len: usize,
        snapshot: &(Vec<LlamaToken>, Vec<u8>, i32),
    ) -> Result<()> {
        let new_tokens = &current_tokens[match_len..];
        // `chunks` panics on a zero size, so treat a misconfigured 0 as 1.
        let chunk_size = cfg.n_tokens.max(1);
        let total_new = new_tokens.len();
        let mut processed = 0;

        for chunk in new_tokens.chunks(chunk_size) {
            batch.clear();
            for &token in chunk {
                processed += 1;
                let wants_logits = processed == total_new;
                batch
                    .add(token, session.pos, &[0], wants_logits)
                    .map_err(|e| {
                        Self::rollback(session, context, snapshot);
                        AmbiError::EngineError(format!("Batch add failed: {}", e))
                    })?;
                session.pos += 1;
            }
            context.decode(batch).map_err(|e| {
                Self::rollback(session, context, snapshot);
                AmbiError::EngineError(format!("Decoding failed: {}", e))
            })?;
        }
        session.history_tokens = current_tokens.to_vec();
        Ok(())
    }

    /// Build the sampler chain: penalties, top-p, temperature, then seeded
    /// random selection.
    fn create_sampler(cfg: &LlamaEngineConfig) -> LlamaSampler {
        LlamaSampler::chain_simple([
            LlamaSampler::penalties(
                cfg.penalty_last_n,
                cfg.penalty_repeat,
                cfg.penalty_freq,
                cfg.penalty_present,
            ),
            LlamaSampler::top_p(cfg.top_p, cfg.min_keep),
            LlamaSampler::temp(cfg.temp),
            LlamaSampler::dist(cfg.seed),
        ])
    }

    /// Sample, stream and decode tokens until end-of-generation, the token
    /// budget is exhausted, or `callback` asks to stop.
    ///
    /// Text is streamed through `session.utf8_buffer`, so multi-byte characters
    /// split across tokens are emitted only once complete. Stopping via the
    /// callback, or any batch/decode failure, rolls the session back.
    fn generation_loop<F>(
        session: &mut InferenceSession,
        model: &LlamaModel,
        context: &mut LlamaContext,
        batch: &mut LlamaBatch,
        cfg: &LlamaEngineConfig,
        sampler: &mut LlamaSampler,
        dynamic_max_tokens: usize,
        snapshot: &(Vec<LlamaToken>, Vec<u8>, i32),
        callback: &mut F,
    ) -> Result<()>
    where
        F: FnMut(String) -> bool,
    {
        let mut decoded_count = 0;
        loop {
            // `sample` already accepts the chosen token into the sampler chain
            // (llama_sampler_sample calls llama_sampler_accept). Accepting it
            // again would double-count it in the repetition penalties.
            let next_token = sampler.sample(context, batch.n_tokens() - 1);

            if model.is_eog_token(next_token) || decoded_count >= dynamic_max_tokens {
                break;
            }

            session.history_tokens.push(next_token);

            match model.token_to_piece_bytes(next_token, cfg.buffer_size, true, None) {
                Ok(bytes) => {
                    session.utf8_buffer.extend_from_slice(&bytes);
                    let text = Self::take_decodable_text(&mut session.utf8_buffer);
                    if !text.is_empty() && !callback(text) {
                        Self::rollback(session, context, snapshot);
                        return Ok(());
                    }
                }
                Err(e) => warn!("Failed to convert sampled token to text: {}", e),
            }

            batch.clear();
            batch
                .add(next_token, session.pos, &[0], true)
                .map_err(|e| {
                    Self::rollback(session, context, snapshot);
                    AmbiError::EngineError(format!("Batch add failed: {}", e))
                })?;
            context.decode(batch).map_err(|e| {
                Self::rollback(session, context, snapshot);
                AmbiError::EngineError(format!("Decoding failed: {}", e))
            })?;

            session.pos += 1;
            decoded_count += 1;
        }

        // Any bytes still buffered are an incomplete character that can no
        // longer be completed. Left in place, they would corrupt the start of
        // the next response.
        if !session.utf8_buffer.is_empty() {
            debug!(
                "Discarding {} trailing bytes of an incomplete UTF-8 sequence.",
                session.utf8_buffer.len()
            );
            session.utf8_buffer.clear();
        }

        debug!("Generation finished after {} new tokens.", decoded_count);
        Ok(())
    }

    /// Remove and return every decodable character at the front of `buffer`.
    ///
    /// Invalid byte sequences are replaced with U+FFFD and consumed, so a bad
    /// byte can never block later output. A multi-byte sequence that is only
    /// incomplete at the end of the buffer is kept, so the next token's bytes
    /// can complete it.
    fn take_decodable_text(buffer: &mut Vec<u8>) -> String {
        let mut text = String::new();
        let mut consumed = 0;
        while consumed < buffer.len() {
            match std::str::from_utf8(&buffer[consumed..]) {
                Ok(valid) => {
                    text.push_str(valid);
                    consumed = buffer.len();
                }
                Err(e) => {
                    let valid_end = consumed + e.valid_up_to();
                    // The prefix up to `valid_up_to` is guaranteed valid, so
                    // this borrows without allocating or replacing anything.
                    text.push_str(&String::from_utf8_lossy(&buffer[consumed..valid_end]));
                    match e.error_len() {
                        Some(bad_len) => {
                            text.push(char::REPLACEMENT_CHARACTER);
                            consumed = valid_end + bad_len;
                        }
                        None => {
                            consumed = valid_end;
                            break;
                        }
                    }
                }
            }
        }
        buffer.drain(..consumed);
        text
    }

    /// Undo the current inference after a failure or early stop.
    ///
    /// The whole KV cache is wiped, because its contents may no longer match
    /// the snapshot. The token history and position are therefore cleared too,
    /// so the next run re-evaluates its prompt instead of trusting entries that
    /// are no longer cached.
    fn rollback(
        session: &mut InferenceSession,
        context: &mut LlamaContext,
        snapshot: &(Vec<LlamaToken>, Vec<u8>, i32),
    ) {
        session.restore(snapshot.clone());
        context.clear_kv_cache();
        session.history_tokens.clear();
        session.pos = 0;
    }
}