use super::session::InferenceSession;
use crate::error::{AmbiError, Result};
use crate::llm::providers::llama_cpp::vision::VisionContext;
use crate::types::config::LlamaEngineConfig;
use llama_cpp_2::context::LlamaContext;
use llama_cpp_2::llama_batch::LlamaBatch;
use llama_cpp_2::model::{AddBos, LlamaModel};
use llama_cpp_2::sampling::LlamaSampler;
use llama_cpp_2::token::LlamaToken;
use log::{debug, info, warn};
impl InferenceSession {
pub(crate) fn run_inference<F>(
prompt: &str,
images: &[String],
vision_ctx: Option<&VisionContext>,
model: &LlamaModel,
context: &mut LlamaContext,
batch: &mut LlamaBatch,
session: &mut InferenceSession,
cfg: &LlamaEngineConfig,
mut callback: F,
) -> Result<()>
where
F: FnMut(String) -> bool,
{
debug!("\n{}\n========================================", prompt);
let snapshot = session.snapshot();
let current_tokens = Self::tokenize_prompt(model, prompt)?;
Self::validate_prompt_length(¤t_tokens, cfg)?;
let dynamic_max_tokens = Self::calculate_max_tokens(¤t_tokens, cfg)?;
let match_len = Self::handle_kv_cache(session, context, ¤t_tokens);
Self::process_images(images, vision_ctx)?;
Self::eval_new_tokens(
session,
context,
batch,
cfg,
¤t_tokens,
match_len,
&snapshot,
)?;
let mut sampler = Self::create_sampler(cfg);
Self::generation_loop(
session,
model,
context,
batch,
cfg,
&mut sampler,
dynamic_max_tokens,
&snapshot,
&mut callback,
)
}
fn tokenize_prompt(model: &LlamaModel, prompt: &str) -> Result<Vec<LlamaToken>> {
let tokens = model
.str_to_token(prompt, AddBos::Always)
.map_err(|e| AmbiError::EngineError(format!("Tokenize failed: {}", e)))?;
Ok(tokens.to_vec())
}
fn validate_prompt_length(tokens: &[LlamaToken], cfg: &LlamaEngineConfig) -> Result<()> {
if tokens.len() >= cfg.n_ctx as usize {
return Err(AmbiError::EngineError(format!(
"Prompt size ({} tokens) exceeds or equals n_ctx limit ({})",
tokens.len(),
cfg.n_ctx
)));
}
Ok(())
}
fn calculate_max_tokens(tokens: &[LlamaToken], cfg: &LlamaEngineConfig) -> Result<usize> {
let dynamic_max = std::cmp::min(
cfg.max_tokens as usize,
(cfg.n_ctx as usize).saturating_sub(tokens.len()),
);
if dynamic_max < 32 {
return Err(AmbiError::EngineError(format!(
"Insufficient token space left for generation (only {} tokens). \
Increase n_ctx or reduce prompt length.",
dynamic_max
)));
}
Ok(dynamic_max)
}
fn handle_kv_cache(
session: &mut InferenceSession,
context: &mut LlamaContext,
current_tokens: &[LlamaToken],
) -> usize {
let mut match_len = 0;
for (t1, t2) in session.history_tokens.iter().zip(current_tokens.iter()) {
if t1 == t2 {
match_len += 1;
} else {
break;
}
}
if match_len < session.history_tokens.len() {
let evicted_len = session.history_tokens.len() - match_len;
info!(
"Evicting {} tokens, applying KV‑cache shift to save evaluation cost.",
evicted_len
);
let p0 = match_len as u32;
let p1 = (match_len + evicted_len) as u32;
if let Err(e) = context.clear_kv_cache_seq(Some(0), Some(p0), Some(p1)) {
warn!(
"Failed to cleanly remove KV‑cache sequence: {}. Falling back to full reset.",
e
);
context.clear_kv_cache();
match_len = 0;
} else if let Err(e) =
context.kv_cache_seq_add(0, Some(p1), None, -(evicted_len as i32))
{
warn!(
"Failed to shift KV‑cache sequence: {}. Falling back to full reset.",
e
);
context.clear_kv_cache();
match_len = 0;
} else {
session.history_tokens.truncate(match_len);
session.pos = match_len as i32;
}
} else {
session.pos = match_len as i32;
}
match_len
}
fn process_images(images: &[String], vision_ctx: Option<&VisionContext>) -> Result<()> {
if images.is_empty() {
return Ok(());
}
match vision_ctx {
None => Err(AmbiError::EngineError(
"Multimodal input received, but no vision context is configured. \
Set `mmproj_path` or `integrated_vision` in LlamaEngineConfig."
.into(),
)),
Some(VisionContext::ExternalProjector { .. }) => Err(AmbiError::EngineError(
"External projector (mmproj) multimodal support is not yet implemented. \
It will be available in Ambi 0.3.0."
.into(),
)),
Some(VisionContext::Integrated) => Ok(()),
}
}
fn eval_new_tokens(
session: &mut InferenceSession,
context: &mut LlamaContext,
batch: &mut LlamaBatch,
cfg: &LlamaEngineConfig,
current_tokens: &[LlamaToken],
match_len: usize,
snapshot: &(Vec<LlamaToken>, Vec<u8>, i32),
) -> Result<()> {
let new_tokens = ¤t_tokens[match_len..];
let chunk_size = cfg.n_tokens;
let total_new = new_tokens.len();
let mut processed = 0;
for chunk in new_tokens.chunks(chunk_size) {
batch.clear();
for &t in chunk.iter() {
processed += 1;
let is_last = processed == total_new;
batch.add(t, session.pos, &[0], is_last).map_err(|e| {
Self::rollback(session, context, snapshot);
AmbiError::EngineError(format!("Batch add failed: {}", e))
})?;
session.pos += 1;
}
if !chunk.is_empty() {
context.decode(batch).map_err(|e| {
Self::rollback(session, context, snapshot);
AmbiError::EngineError(format!("Decoding failed: {}", e))
})?;
}
}
session.history_tokens = current_tokens.to_vec();
Ok(())
}
fn create_sampler(cfg: &LlamaEngineConfig) -> LlamaSampler {
LlamaSampler::chain_simple([
LlamaSampler::penalties(
cfg.penalty_last_n,
cfg.penalty_repeat,
cfg.penalty_freq,
cfg.penalty_present,
),
LlamaSampler::top_p(cfg.top_p, cfg.min_keep),
LlamaSampler::temp(cfg.temp),
LlamaSampler::dist(cfg.seed),
])
}
fn generation_loop<F>(
session: &mut InferenceSession,
model: &LlamaModel,
context: &mut LlamaContext,
batch: &mut LlamaBatch,
cfg: &LlamaEngineConfig,
sampler: &mut LlamaSampler,
dynamic_max_tokens: usize,
snapshot: &(Vec<LlamaToken>, Vec<u8>, i32),
callback: &mut F,
) -> Result<()>
where
F: FnMut(String) -> bool,
{
let mut decoded_count = 0;
loop {
let next_token = sampler.sample(context, batch.n_tokens() - 1);
sampler.accept(next_token);
if model.is_eog_token(next_token) || decoded_count >= dynamic_max_tokens {
break;
}
session.history_tokens.push(next_token);
if let Ok(bytes) = model.token_to_piece_bytes(next_token, cfg.buffer_size, true, None) {
session.utf8_buffer.extend_from_slice(&bytes);
let should_stop = match std::str::from_utf8(&session.utf8_buffer) {
Ok(valid_str) => {
let stop = !callback(valid_str.to_string());
if stop {
true
} else {
session.utf8_buffer.clear();
false
}
}
Err(e) => {
let valid_up_to = e.valid_up_to();
if valid_up_to > 0 {
let valid_str = unsafe {
std::str::from_utf8_unchecked(&session.utf8_buffer[..valid_up_to])
};
let stop = !callback(valid_str.to_string());
if stop {
true
} else {
session.utf8_buffer.drain(..valid_up_to);
false
}
} else {
false
}
}
};
if should_stop {
Self::rollback(session, context, snapshot);
return Ok(());
}
}
batch.clear();
batch
.add(next_token, session.pos, &[0], true)
.map_err(|e| {
Self::rollback(session, context, snapshot);
AmbiError::EngineError(format!("Batch add failed: {}", e))
})?;
context.decode(batch).map_err(|e| {
Self::rollback(session, context, snapshot);
AmbiError::EngineError(format!("Decoding failed: {}", e))
})?;
session.pos += 1;
decoded_count += 1;
}
debug!("Generation finished after {} new tokens.", decoded_count);
Ok(())
}
fn rollback(
session: &mut InferenceSession,
context: &mut LlamaContext,
snapshot: &(Vec<LlamaToken>, Vec<u8>, i32),
) {
session.restore(snapshot.clone());
context.clear_kv_cache();
}
}