use super::config::LlamaEngineConfig;
use super::session::InferenceSession;
use super::vision::VisionContext;
use crate::error::{AmbiError, Result};
use llama_cpp_2::context::LlamaContext;
use llama_cpp_2::llama_batch::LlamaBatch;
use llama_cpp_2::model::{AddBos, LlamaModel};
use llama_cpp_2::sampling::LlamaSampler;
use llama_cpp_2::token::LlamaToken;
use log::{debug, info, warn};
#[cfg(feature = "mtmd")]
use llama_cpp_2::mtmd::MtmdInputText;
/// Borrowed inputs for a single call to `InferenceSession::run_inference`.
pub(crate) struct InferenceInput<'a> {
/// Prompt text; it is tokenized before being evaluated.
pub prompt: &'a str,
/// Image inputs. An empty slice selects the text-only path.
/// NOTE(review): the string format (file paths vs. encoded data) is not
/// visible here — confirm against `VisionContext::create_bitmaps`.
pub images: &'a [String],
/// Vision backend. Inference fails if this is `None` while `images` is non-empty.
pub vision_ctx: Option<&'a VisionContext>,
/// Engine configuration (context size, batch size, sampling parameters).
pub cfg: &'a LlamaEngineConfig,
}
impl InferenceSession {
/// Runs one complete inference pass: tokenizes the prompt, evaluates the
/// prompt tokens that are not already covered by the KV cache, then streams
/// generated text through `callback`.
///
/// When images are supplied and the vision context is an external projector
/// (only with the `mtmd` feature), the whole request is delegated to
/// `multimodal_inference` instead.
///
/// `callback` receives each decoded text fragment; returning `false` stops
/// generation early.
///
/// # Errors
///
/// Returns `AmbiError::EngineError` when tokenization fails, the prompt does
/// not fit in the context window, too little room is left for generation,
/// images are supplied without a usable vision context, or batching/decoding
/// fails.
pub(crate) fn run_inference<F>(
    input: InferenceInput<'_>,
    model: &LlamaModel,
    context: &mut LlamaContext,
    batch: &mut LlamaBatch,
    session: &mut InferenceSession,
    mut callback: F,
) -> Result<()>
where
    F: FnMut(String) -> bool,
{
    debug!(
        "\n{}\n========================================",
        input.prompt
    );

    // Captured before anything mutates the session so failures can roll back.
    let snapshot = session.snapshot();

    let current_tokens = Self::tokenize_prompt(model, input.prompt)?;
    Self::validate_prompt_length(&current_tokens, input.cfg)?;
    let dynamic_max_tokens = Self::calculate_max_tokens(&current_tokens, input.cfg)?;

    // KV-cache prefix reuse is only attempted for text-only requests.
    let match_len = if input.images.is_empty() {
        Self::handle_kv_cache(session, context, &current_tokens)
    } else {
        0
    };

    Self::process_images(input.images, input.vision_ctx)?;

    #[cfg(feature = "mtmd")]
    if !input.images.is_empty()
        && matches!(
            input.vision_ctx,
            Some(VisionContext::ExternalProjector { .. })
        )
    {
        return Self::multimodal_inference(
            input,
            model,
            context,
            batch,
            session,
            &mut callback,
        );
    }

    Self::eval_new_tokens(
        session,
        context,
        batch,
        input.cfg,
        &current_tokens,
        match_len,
        &snapshot,
    )?;

    let mut sampler = Self::create_sampler(input.cfg);
    Self::generation_loop(
        session,
        model,
        context,
        batch,
        input.cfg,
        &mut sampler,
        dynamic_max_tokens,
        &snapshot,
        &mut callback,
    )
}
/// Tokenizes `prompt` with the model's vocabulary, always prepending BOS.
///
/// # Errors
///
/// Returns `AmbiError::EngineError` if the model fails to tokenize the text.
fn tokenize_prompt(model: &LlamaModel, prompt: &str) -> Result<Vec<LlamaToken>> {
    // `str_to_token` already yields an owned vector, so it is returned
    // directly instead of being copied a second time.
    model
        .str_to_token(prompt, AddBos::Always)
        .map_err(|e| AmbiError::EngineError(format!("Tokenize failed: {}", e)))
}
/// Rejects prompts whose token count is equal to or larger than `cfg.n_ctx`.
///
/// # Errors
///
/// Returns `AmbiError::EngineError` reporting both the prompt size and the limit.
fn validate_prompt_length(tokens: &[LlamaToken], cfg: &LlamaEngineConfig) -> Result<()> {
    let prompt_len = tokens.len();
    if prompt_len < cfg.n_ctx as usize {
        return Ok(());
    }
    Err(AmbiError::EngineError(format!(
        "Prompt size ({} tokens) exceeds or equals n_ctx limit ({})",
        prompt_len, cfg.n_ctx
    )))
}
/// Computes how many tokens may be generated for this prompt: the smaller of
/// `cfg.max_tokens` and the context space left after the prompt.
///
/// # Errors
///
/// Returns `AmbiError::EngineError` when fewer than 32 tokens are available.
fn calculate_max_tokens(tokens: &[LlamaToken], cfg: &LlamaEngineConfig) -> Result<usize> {
    let remaining_ctx = (cfg.n_ctx as usize).saturating_sub(tokens.len());
    let budget = (cfg.max_tokens as usize).min(remaining_ctx);
    if budget >= 32 {
        return Ok(budget);
    }
    Err(AmbiError::EngineError(format!(
        "Insufficient token space left for generation (only {} tokens). \
         Increase n_ctx or reduce prompt length.",
        budget
    )))
}
/// Reconciles the KV cache with a new prompt and returns how many leading
/// prompt tokens can be reused without re-evaluation.
///
/// The reusable prefix is the longest common prefix of
/// `session.history_tokens` and `current_tokens`, capped at
/// `current_tokens.len() - 1` so that at least one prompt token is always
/// decoded again. Without that cap a fully cached prompt would decode nothing
/// and sampling would read logits left over from an earlier decode.
///
/// History entries past the reusable prefix are removed from the cache. If
/// that fails, the cache and the session's tracking state are reset and `0`
/// is returned so the whole prompt is evaluated from position zero.
fn handle_kv_cache(
    session: &mut InferenceSession,
    context: &mut LlamaContext,
    current_tokens: &[LlamaToken],
) -> usize {
    let shared_prefix = session
        .history_tokens
        .iter()
        .zip(current_tokens)
        .take_while(|(cached, wanted)| cached == wanted)
        .count();
    // Keep the last prompt token out of the reused prefix so fresh logits exist.
    let match_len = shared_prefix.min(current_tokens.len().saturating_sub(1));

    let history_len = session.history_tokens.len();
    if match_len >= history_len {
        // The cached history is entirely a prefix of the prompt: nothing to evict.
        session.pos = match_len as i32;
        return match_len;
    }

    let evicted_len = history_len - match_len;
    info!(
        "Evicting {} tokens, applying KV‑cache shift to save evaluation cost.",
        evicted_len
    );
    let p0 = match_len as u32;
    let p1 = history_len as u32;

    let failure = if let Err(e) = context.clear_kv_cache_seq(Some(0), Some(p0), Some(p1)) {
        Some(format!(
            "Failed to cleanly remove KV‑cache sequence: {}. Falling back to full reset.",
            e
        ))
    } else if let Err(e) = context.kv_cache_seq_add(0, Some(p1), None, -(evicted_len as i32)) {
        Some(format!(
            "Failed to shift KV‑cache sequence: {}. Falling back to full reset.",
            e
        ))
    } else {
        None
    };

    match failure {
        Some(message) => {
            warn!("{}", message);
            context.clear_kv_cache();
            // The cache is empty now, so the session must not claim any cached
            // tokens and evaluation has to restart at position zero.
            session.history_tokens.clear();
            session.pos = 0;
            0
        }
        None => {
            session.history_tokens.truncate(match_len);
            session.pos = match_len as i32;
            match_len
        }
    }
}
/// Checks that image input can be served by the configured vision context.
///
/// Succeeds when there are no images, or when the context is an external
/// projector (only available with the `mtmd` feature).
///
/// # Errors
///
/// Returns `AmbiError::EngineError` when images are supplied but no vision
/// context is configured, or when the context is `VisionContext::Integrated`.
fn process_images(images: &[String], vision_ctx: Option<&VisionContext>) -> Result<()> {
    if images.is_empty() {
        return Ok(());
    }

    let ctx = match vision_ctx {
        Some(ctx) => ctx,
        None => {
            return Err(AmbiError::EngineError(
                "Multimodal input received, but no vision context is configured. \
                 Set `mmproj_path` or `integrated_vision` in LlamaEngineConfig."
                    .into(),
            ))
        }
    };

    match ctx {
        #[cfg(feature = "mtmd")]
        VisionContext::ExternalProjector { .. } => Ok(()),
        VisionContext::Integrated => Err(AmbiError::EngineError(
            "Native integrated vision support is not yet implemented. \
             To use multimodal models, enable the 'mtmd' feature and provide an 'mmproj_path'."
                .into(),
        )),
    }
}
/// Decodes the prompt tokens that are not yet in the KV cache.
///
/// Tokens from `match_len` onward are fed to the model in chunks of
/// `cfg.n_tokens`, starting at `session.pos`. Only the very last prompt token
/// is flagged when added to the batch (presumably to request logits — confirm
/// against `LlamaBatch::add`). On success `session.history_tokens` is replaced
/// by the full prompt.
///
/// `match_len` must not exceed `current_tokens.len()`.
///
/// # Errors
///
/// Returns `AmbiError::EngineError` if adding to the batch or decoding fails;
/// the session is rolled back to `snapshot` first.
fn eval_new_tokens(
    session: &mut InferenceSession,
    context: &mut LlamaContext,
    batch: &mut LlamaBatch,
    cfg: &LlamaEngineConfig,
    current_tokens: &[LlamaToken],
    match_len: usize,
    snapshot: &(Vec<LlamaToken>, Vec<u8>, i32),
) -> Result<()> {
    let new_tokens = &current_tokens[match_len..];
    // `chunks` panics on a zero chunk size, so a misconfigured batch size
    // degrades to one token per batch instead of aborting.
    let chunk_size = cfg.n_tokens.max(1);
    let total_new = new_tokens.len();
    let mut processed = 0;

    for chunk in new_tokens.chunks(chunk_size) {
        batch.clear();
        for &token in chunk {
            processed += 1;
            let is_last = processed == total_new;
            batch
                .add(token, session.pos, &[0], is_last)
                .map_err(|e| {
                    Self::rollback(session, context, snapshot);
                    AmbiError::EngineError(format!("Batch add failed: {}", e))
                })?;
            session.pos += 1;
        }
        // `chunks` never yields an empty slice, so every batch has work to do.
        context.decode(batch).map_err(|e| {
            Self::rollback(session, context, snapshot);
            AmbiError::EngineError(format!("Decoding failed: {}", e))
        })?;
    }

    session.history_tokens = current_tokens.to_vec();
    Ok(())
}
/// Builds the sampler chain from the configuration: penalties, then top-p,
/// then temperature, then a `dist` sampler seeded with `cfg.seed`.
fn create_sampler(cfg: &LlamaEngineConfig) -> LlamaSampler {
    let penalties_stage = LlamaSampler::penalties(
        cfg.penalty_last_n,
        cfg.penalty_repeat,
        cfg.penalty_freq,
        cfg.penalty_present,
    );
    let top_p_stage = LlamaSampler::top_p(cfg.top_p, cfg.min_keep);
    let temp_stage = LlamaSampler::temp(cfg.temp);
    let dist_stage = LlamaSampler::dist(cfg.seed);

    LlamaSampler::chain_simple([penalties_stage, top_p_stage, temp_stage, dist_stage])
}
#[allow(clippy::too_many_arguments)]
fn generation_loop<F>(
session: &mut InferenceSession,
model: &LlamaModel,
context: &mut LlamaContext,
batch: &mut LlamaBatch,
cfg: &LlamaEngineConfig,
sampler: &mut LlamaSampler,
dynamic_max_tokens: usize,
snapshot: &(Vec<LlamaToken>, Vec<u8>, i32),
callback: &mut F,
) -> Result<()>
where
F: FnMut(String) -> bool,
{
let mut decoded_count = 0;
let mut logits_idx = -1;
loop {
let next_token = sampler.sample(context, logits_idx);
sampler.accept(next_token);
if model.is_eog_token(next_token) || decoded_count >= dynamic_max_tokens {
break;
}
session.history_tokens.push(next_token);
if let Ok(bytes) = model.token_to_piece_bytes(next_token, cfg.buffer_size, true, None) {
session.utf8_buffer.extend_from_slice(&bytes);
let should_stop = match std::str::from_utf8(&session.utf8_buffer) {
Ok(valid_str) => {
let stop = !callback(valid_str.to_string());
if stop {
true
} else {
session.utf8_buffer.clear();
false
}
}
Err(e) => {
let valid_up_to = e.valid_up_to();
if valid_up_to > 0 {
let valid_str = unsafe {
std::str::from_utf8_unchecked(&session.utf8_buffer[..valid_up_to])
};
let stop = !callback(valid_str.to_string());
if stop {
true
} else {
session.utf8_buffer.drain(..valid_up_to);
false
}
} else {
false
}
}
};
if should_stop {
Self::rollback(session, context, snapshot);
return Ok(());
}
}
batch.clear();
batch
.add(next_token, session.pos, &[0], true)
.map_err(|e| {
Self::rollback(session, context, snapshot);
AmbiError::EngineError(format!("Batch add failed: {}", e))
})?;
context.decode(batch).map_err(|e| {
Self::rollback(session, context, snapshot);
AmbiError::EngineError(format!("Decoding failed: {}", e))
})?;
logits_idx = 0;
session.pos += 1;
decoded_count += 1;
}
debug!("Generation finished after {} new tokens.", decoded_count);
Ok(())
}
/// Abandons the current inference attempt: restores the session from
/// `snapshot`, clears the KV cache, and drops the session's cached-token
/// tracking.
///
/// `handle_kv_cache` treats `session.history_tokens` as the content of the KV
/// cache. Because the cache is emptied here, the history and position are
/// emptied as well; otherwise the next request would reuse a prefix that is no
/// longer cached.
fn rollback(
    session: &mut InferenceSession,
    context: &mut LlamaContext,
    snapshot: &(Vec<LlamaToken>, Vec<u8>, i32),
) {
    session.restore(snapshot.clone());
    context.clear_kv_cache();
    // Keep the tracking state consistent with the now-empty cache.
    session.history_tokens.clear();
    session.pos = 0;
}
/// Runs inference for a prompt that carries images, using the external
/// projector (MTMD) context.
///
/// The session and KV cache are reset first, so no prefix is reused. The
/// prompt and images are tokenized into chunks, all chunks are evaluated, and
/// generation then proceeds through `generation_loop`. Only the text tokens of
/// the chunks are recorded in `session.history_tokens`, while `session.pos`
/// is set to the position returned by the chunk evaluation.
///
/// # Errors
///
/// Returns `AmbiError::EngineError` when the vision context is not an external
/// projector, when bitmap creation, tokenization or chunk evaluation fails,
/// or when fewer than 32 context positions remain for generation.
#[cfg(feature = "mtmd")]
fn multimodal_inference<F>(
    input: InferenceInput<'_>,
    model: &LlamaModel,
    context: &mut LlamaContext,
    batch: &mut LlamaBatch,
    session: &mut InferenceSession,
    callback: &mut F,
) -> Result<()>
where
    F: FnMut(String) -> bool,
{
    use llama_cpp_2::mtmd::MtmdInputChunkType;

    session.reset();
    context.clear_kv_cache();
    // Taken after the reset: a rollback clears the KV cache, so the state it
    // restores must not claim any cached history.
    let snapshot = session.snapshot();

    let missing_ctx =
        || AmbiError::EngineError("Missing MTMD context for multimodal inference".into());
    let vision_ctx = input.vision_ctx.ok_or_else(missing_ctx)?;
    let mtmd_ctx = match vision_ctx {
        VisionContext::ExternalProjector { mtmd_ctx } => mtmd_ctx,
        _ => return Err(missing_ctx()),
    };

    let bitmaps = vision_ctx.create_bitmaps(input.images)?;
    let bitmap_refs: Vec<_> = bitmaps.iter().collect();
    let text_input = MtmdInputText {
        text: input.prompt.to_string(),
        add_special: true,
        parse_special: true,
    };
    let chunks = mtmd_ctx
        .tokenize(text_input, &bitmap_refs)
        .map_err(|e| AmbiError::EngineError(format!("MTMD tokenize error: {}", e)))?;

    // Only text chunks contribute tokens to the tracked history.
    let mut all_tokens = Vec::new();
    for i in 0..chunks.len() {
        if let Some(chunk) = chunks.get(i) {
            if chunk.chunk_type() == MtmdInputChunkType::Text {
                if let Some(tokens) = chunk.text_tokens() {
                    all_tokens.extend_from_slice(tokens);
                }
            }
        }
    }

    let n_past = 0;
    let new_n_past = chunks
        .eval_chunks(
            mtmd_ctx,
            context,
            n_past,
            0,
            input.cfg.n_tokens as i32,
            true,
        )
        .map_err(|e| {
            // Chunks evaluated before the failure are still in the cache;
            // drop them so the reset session matches the cache again.
            context.clear_kv_cache();
            AmbiError::EngineError(format!("MTMD eval error: {}", e))
        })?;

    // Image embeddings occupy context positions too, so the generation budget
    // is derived from the evaluated position rather than the text token count.
    // Same 32-token minimum and message as `calculate_max_tokens`.
    let used_positions = usize::try_from(new_n_past).unwrap_or(usize::MAX);
    let dynamic_max_tokens = std::cmp::min(
        input.cfg.max_tokens as usize,
        (input.cfg.n_ctx as usize).saturating_sub(used_positions),
    );
    if dynamic_max_tokens < 32 {
        context.clear_kv_cache();
        return Err(AmbiError::EngineError(format!(
            "Insufficient token space left for generation (only {} tokens). \
             Increase n_ctx or reduce prompt length.",
            dynamic_max_tokens
        )));
    }

    session.history_tokens = all_tokens;
    session.pos = new_n_past;

    let mut sampler = Self::create_sampler(input.cfg);
    Self::generation_loop(
        session,
        model,
        context,
        batch,
        input.cfg,
        &mut sampler,
        dynamic_max_tokens,
        &snapshot,
        callback,
    )
}
}