use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use rig::completion::CompletionError;
use rig::streaming::RawStreamingChoice;
use tokio::sync::mpsc;
use crate::checkpoint::{
PersistentCtx, ensure_persistent_ctx, maybe_create_checkpoint, restore_or_clear,
};
use crate::error::LoadError;
#[cfg(feature = "mtmd")]
use crate::image::run_image_inference;
use crate::loader::{WorkerModel, fit_and_load_model};
use crate::prompt::build_prompt;
use crate::sampling::sample_tokens;
use crate::slot::{SlotEntry, get_common_prefix};
use crate::types::{
CheckpointParams, FitParams, InferenceCommand, InferenceParams, InferenceResult, KvCacheParams,
ResponseChannel, StreamChunk, StreamSender,
};
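/// Error string reported when an in-flight request is aborted via the shared cancel flag.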
pub(crate) const CANCEL_ERR: &str = "inference cancelled";
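/// Outcome of the inner command loop: either reload the model with new parameters
/// or shut the worker down.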
enum LoopOutcome {
Reload(crate::types::ReloadRequest),
Shutdown,
}
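/// Borrowed, per-request view of the loaded model plus the parameters and cancel
/// flag needed to run a single inference.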
pub(crate) struct RunCtx<'a, 'm> {
pub(crate) backend: &'m llama_cpp_2::llama_backend::LlamaBackend,
pub(crate) model: &'m llama_cpp_2::model::LlamaModel,
pub(crate) n_ctx: u32,
pub(crate) kv_cache: &'a KvCacheParams,
pub(crate) checkpoint_params: CheckpointParams,
#[cfg(feature = "mtmd")]
pub(crate) mtmd_ctx: Option<&'a llama_cpp_2::mtmd::MtmdContext>,
pub(crate) cancel: &'a AtomicBool,
}
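/// Everything the worker thread needs to perform the initial model load and start
/// serving requests.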
pub(crate) struct WorkerInit<'a> {
pub(crate) model_path: &'a str,
pub(crate) mmproj_path: Option<&'a str>,
pub(crate) n_ctx: u32,
pub(crate) fit_params: &'a FitParams,
pub(crate) kv_cache_params: &'a KvCacheParams,
pub(crate) checkpoint_params: CheckpointParams,
pub(crate) cancel: Arc<AtomicBool>,
}
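/// Serves requests on `rx` until a reload or shutdown command arrives (or the cancel
/// flag is observed), reusing one persistent context across requests so the KV-cache
/// prefix can be shared between consecutive prompts.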
fn handle_until_reload<'m>(
backend: &'m llama_cpp_2::llama_backend::LlamaBackend,
wm: &'m WorkerModel,
checkpoint_params: CheckpointParams,
cancel: &AtomicBool,
rx: &mut mpsc::Receiver<InferenceCommand>,
) -> LoopOutcome {
let mut persistent: Option<PersistentCtx<'m>> = None;
while let Some(command) = rx.blocking_recv() {
match command {
InferenceCommand::Request(req) => {
let crate::types::InferenceRequest {
params,
response_channel,
} = req;
let ctx = RunCtx {
backend,
model: &wm.model,
n_ctx: wm.n_ctx,
kv_cache: &wm.kv_cache,
checkpoint_params,
#[cfg(feature = "mtmd")]
mtmd_ctx: wm.mtmd_ctx.as_ref(),
cancel,
};
match response_channel {
ResponseChannel::Completion(tx) => {
let result = run_inference(&ctx, &mut persistent, &params, None);
let _ = tx.send(result);
}
ResponseChannel::Streaming(stream_tx) => {
let result =
run_inference(&ctx, &mut persistent, &params, Some(&stream_tx));
match result {
Ok(result) => {
let _ = stream_tx.send(Ok(RawStreamingChoice::FinalResponse(
StreamChunk {
text: result.text,
prompt_tokens: Some(result.prompt_tokens),
completion_tokens: Some(result.completion_tokens),
cached_input_tokens: Some(result.cached_input_tokens),
},
)));
}
Err(e) => {
let _ = stream_tx.send(Err(CompletionError::ProviderError(e)));
}
}
}
}
}
InferenceCommand::Reload(reload) => return LoopOutcome::Reload(reload),
InferenceCommand::Shutdown => return LoopOutcome::Shutdown,
}
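// A cancel flag that is still set after handling a command shuts the worker down.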
if cancel.load(Ordering::Relaxed) {
return LoopOutcome::Shutdown;
}
}
LoopOutcome::Shutdown
}
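/// Entry point for the dedicated worker thread: loads the model, reports the result
/// on `init_tx`, then loops handling inference commands, reloading the model in place
/// whenever a `Reload` command arrives and exiting on shutdown or a failed reload.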
pub(crate) fn inference_worker(
init: WorkerInit<'_>,
init_tx: std::sync::mpsc::Sender<Result<(), LoadError>>,
rx: &mut mpsc::Receiver<InferenceCommand>,
) {
let backend = match crate::shared_backend() {
Ok(b) => b,
Err(e) => {
let _ = init_tx.send(Err(LoadError::BackendInit(e)));
return;
}
};
let logs_enabled = crate::llama_logs_enabled();
let mut wm = match fit_and_load_model(
backend,
init.model_path,
init.mmproj_path,
init.n_ctx,
init.fit_params,
init.kv_cache_params,
logs_enabled,
) {
Ok(wm) => wm,
Err(e) => {
let _ = init_tx.send(Err(e));
return;
}
};
let _ = init_tx.send(Ok(()));
let mut checkpoint_params = init.checkpoint_params;
let cancel = init.cancel;
while let LoopOutcome::Reload(reload) =
handle_until_reload(backend, &wm, checkpoint_params, &cancel, rx)
{
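// Drop the current model before loading its replacement so both never occupy
// memory at the same time.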
drop(wm);
let result = fit_and_load_model(
backend,
&reload.model_path,
reload.mmproj_path.as_deref(),
reload.n_ctx,
&reload.fit_params,
&reload.kv_cache_params,
logs_enabled,
);
match result {
Ok(new_wm) => {
wm = new_wm;
checkpoint_params = reload.checkpoint_params;
let _ = reload.result_tx.send(Ok(()));
}
Err(e) => {
let _ = reload.result_tx.send(Err(e));
return;
}
}
}
}
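/// Dispatches to image inference when the request carries images and an mtmd context
/// is available (feature "mtmd"); otherwise falls back to text-only inference.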
fn run_inference<'m>(
ctx: &RunCtx<'_, 'm>,
persistent: &mut Option<PersistentCtx<'m>>,
req: &InferenceParams,
stream_tx: Option<&StreamSender>,
) -> Result<InferenceResult, String> {
#[cfg(feature = "mtmd")]
{
let has_images = !req.prepared_request.images.is_empty();
if has_images && ctx.mtmd_ctx.is_some() {
return run_image_inference(ctx, persistent, req, stream_tx);
}
}
run_text_inference(ctx, persistent, req, stream_tx)
}
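/// Text-only path: builds and tokenizes the prompt, reuses the longest cached prefix
/// from the persistent context, decodes the remaining suffix, then samples the
/// completion. The persistent context is reset after decode or sampling failures so
/// the next request starts from a clean state.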
fn run_text_inference<'m>(
ctx: &RunCtx<'_, 'm>,
persistent: &mut Option<PersistentCtx<'m>>,
req: &InferenceParams,
stream_tx: Option<&StreamSender>,
) -> Result<InferenceResult, String> {
use llama_cpp_2::model::AddBos;
let prompt_build = build_prompt(ctx.model, &req.prepared_request)?;
let prompt = prompt_build.prompt.as_str();
let new_tokens = ctx
.model
.str_to_token(prompt, AddBos::Always)
.map_err(|e| format!("Tokenization failed: {e}"))?;
let prompt_len = new_tokens.len();
if prompt_len == 0 {
return Err("Empty prompt after tokenization".to_string());
}
if prompt_len > ctx.n_ctx as usize {
return Err(format!(
"Prompt {prompt_len} tokens exceeds n_ctx {}",
ctx.n_ctx
));
}
let new_entries: Vec<SlotEntry> = new_tokens.iter().map(|t| SlotEntry::Text(*t)).collect();
let cached = {
let p = ensure_persistent_ctx(ctx.backend, ctx.model, ctx.n_ctx, ctx.kv_cache, persistent)?;
get_common_prefix(&p.last_entries, &new_entries)
};
let phase1 = {
let p = ensure_persistent_ctx(ctx.backend, ctx.model, ctx.n_ctx, ctx.kv_cache, persistent)?;
prepare_prompt_decode(
p,
&new_tokens,
cached,
prompt_len,
ctx.checkpoint_params,
ctx.cancel,
)
};
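// If decoding on top of a cached prefix fails, retry once from a fresh context before
// giving up; a failure on the retry also invalidates the persistent state.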
let (mut batch, effective_cached) = match phase1 {
Ok(out) => out,
Err(e) if cached > 0 => {
log::warn!(
"prefix-cache decode failed (cached={cached}, prompt_len={prompt_len}): {e}. \
Falling back to fresh-context decode."
);
*persistent = None;
let retry = {
let p = ensure_persistent_ctx(
ctx.backend,
ctx.model,
ctx.n_ctx,
ctx.kv_cache,
persistent,
)?;
prepare_prompt_decode(
p,
&new_tokens,
0,
prompt_len,
ctx.checkpoint_params,
ctx.cancel,
)
};
match retry {
Ok(out) => out,
Err(e) => {
*persistent = None;
return Err(e);
}
}
}
Err(e) => {
*persistent = None;
return Err(e);
}
};
let p = ensure_persistent_ctx(ctx.backend, ctx.model, ctx.n_ctx, ctx.kv_cache, persistent)?;
p.last_entries = new_entries;
let prompt_tokens = prompt_len as u64;
let cached_tokens = effective_cached as u64;
let result = sample_tokens(
ctx.model,
&mut p.ctx,
&mut batch,
&prompt_build,
req,
stream_tx,
prompt_tokens,
cached_tokens,
&mut p.last_entries,
ctx.cancel,
);
if result.is_err() {
*persistent = None;
}
result
}
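/// Trims the KV cache down to the reusable prefix and decodes the remaining prompt
/// tokens in `n_batch`-sized chunks, creating checkpoints along the way. Returns the
/// batch holding the last decoded prompt token (with logits requested) and the number
/// of prompt tokens effectively served from cache.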
fn prepare_prompt_decode<'b>(
p: &mut PersistentCtx<'_>,
new_tokens: &[llama_cpp_2::token::LlamaToken],
cached: usize,
prompt_len: usize,
checkpoint_params: CheckpointParams,
cancel: &AtomicBool,
) -> Result<(llama_cpp_2::llama_batch::LlamaBatch<'b>, usize), String> {
use llama_cpp_2::llama_batch::LlamaBatch;
log::debug!(
"prefix-cache: prompt_len={prompt_len} last_entries.len={} cached={cached} trim_unsupported={} checkpoints={}",
p.last_entries.len(),
p.trim_unsupported,
p.checkpoint_count(),
);
let mut effective_cached = cached;
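// The cached entries extend past the common prefix: trim the stale tail of the KV
// cache, falling back to checkpoint restore when the backend cannot trim a sequence
// partially (e.g. recurrent/hybrid models).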
if cached < p.last_entries.len() {
if p.trim_unsupported {
effective_cached = restore_or_clear(p, cached);
} else {
let removed = p
.ctx
.clear_kv_cache_seq(Some(0), Some(cached as u32), None)
.map_err(|e| format!("KV cache trim failed: {e:?}"))?;
if removed {
p.retain_checkpoints_below(cached);
} else {
log::info!(
"partial KV-cache trim not supported by this model \
(likely recurrent/hybrid). Routing rollbacks through checkpoint restore."
);
p.trim_unsupported = true;
effective_cached = restore_or_clear(p, cached);
}
}
} else {
p.retain_checkpoints_below(cached.max(1));
}
let prompt_batch_limit = p.ctx.n_batch().max(1) as usize;
let mut batch = LlamaBatch::new(prompt_batch_limit, 1);
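// Decode the uncached suffix in chunks no larger than the context's n_batch,
// requesting logits only for the final prompt token.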
if effective_cached < prompt_len {
let suffix = &new_tokens[effective_cached..];
for (chunk_index, chunk) in suffix.chunks(prompt_batch_limit).enumerate() {
if cancel.load(Ordering::Relaxed) {
return Err(CANCEL_ERR.to_string());
}
batch.clear();
for (offset, token) in chunk.iter().copied().enumerate() {
let abs = effective_cached + chunk_index * prompt_batch_limit + offset;
let is_last_prompt_token = abs + 1 == prompt_len;
batch
.add(token, abs as i32, &[0], is_last_prompt_token)
.map_err(|e| format!("Batch add failed: {e}"))?;
}
if batch.n_tokens() == 0 {
return Err(format!(
"BUG: empty prompt batch at chunk {chunk_index} (suffix.len={}, prompt_batch_limit={})",
suffix.len(),
prompt_batch_limit,
));
}
p.ctx
.decode(&mut batch)
.map_err(|e| format!("Prompt decode failed: {e}"))?;
let n_tokens_decoded =
effective_cached + chunk_index * prompt_batch_limit + chunk.len();
maybe_create_checkpoint(p, checkpoint_params, n_tokens_decoded, prompt_len);
}
} else {
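// The whole prompt is already in the KV cache (pure rollback): trim everything from
// the last prompt position onward and re-decode just that token to obtain logits.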
let removed = p
.ctx
.clear_kv_cache_seq(Some(0), Some((prompt_len - 1) as u32), None)
.map_err(|e| format!("KV cache trim failed: {e:?}"))?;
if !removed {
return Err(format!(
"KV cache trim (rollback) returned false at pos {}",
prompt_len - 1
));
}
batch.clear();
batch
.add(
new_tokens[prompt_len - 1],
(prompt_len - 1) as i32,
&[0],
true,
)
.map_err(|e| format!("Batch add failed: {e}"))?;
p.ctx
.decode(&mut batch)
.map_err(|e| format!("Prompt decode failed: {e}"))?;
}
Ok((batch, effective_cached))
}