#![allow(unsafe_code)]
use crate::backend::{
AcceleratorInfo, AcceleratorKind, Backend, BackendCapabilities, EmbedError, EmbedResult,
GenerateError, TokenEvent, TokenEventV2, TokenStream, TokenStreamV2,
};
use crate::ffi;
use crate::llamacpp::chat_template::Gemma4Renderer;
use crate::llamacpp::loader::{ModelHandle, ModelLoadError, load_model};
use crate::llamacpp::mtmd::{Bitmap, Mtmd, MtmdConfig, MtmdError};
use crate::llamacpp::tool_parser::{Output as TokenOutput, ToolCallParser};
use async_trait::async_trait;
use base64::Engine as _;
use inferd_proto::embed::{EmbedResolved, EmbedUsage};
use inferd_proto::v2::{Attachment, ResolvedV2, StopReasonV2, UsageV2};
use inferd_proto::{Resolved, StopReason, Usage};
use std::ffi::CString;
use std::ptr::{self, NonNull};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex, Once};
use std::time::Duration;
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use tracing::{debug, warn};
/// Guard used by `ensure_backend_init` so that `ffi::llama_backend_init` runs at most once.
static LLAMA_BACKEND_INIT: Once = Once::new();
/// Errors produced by the llama.cpp adapter.
///
/// Every variant is flattened to `GenerateError::Internal` (display text only) by the
/// `From<LlamaCppError> for GenerateError` impl in this file.
#[derive(Debug, thiserror::Error)]
pub enum LlamaCppError {
/// Propagated from the model loader (`load_model` and the other loader helpers used via `?`).
#[error("load: {0}")]
Load(#[from] ModelLoadError),
/// `llama_init_from_model` returned a null context; `LlamaCpp::new` also returns this when
/// the model reports a non-positive embedding width.
#[error("llama_init_from_model returned null")]
ContextInit,
/// The sampler chain could not be built, or the supplied grammar was rejected.
#[error("sampler chain init failed")]
Sampler,
/// `llama_tokenize` reported a failure.
#[error("tokenize failed")]
Tokenize,
/// `llama_decode` returned the contained non-zero status code.
#[error("llama_decode failed: {0}")]
Decode(i32),
/// Propagated from the multimodal (`mtmd`) wrapper.
#[error("mtmd: {0}")]
Mtmd(#[from] MtmdError),
/// A v2 generation was requested but no multimodal projector was loaded.
#[error("v2 request requires mmproj but none was configured")]
NoMmproj,
/// Free-form message. In this file it is used for invalid or unsupported attachments and
/// for malformed tool-call output.
#[error("chat template: {0}")]
Render(String),
/// Base64 decoding of the attachment with the contained id failed.
#[error("attachment base64 decode failed for {0:?}")]
Base64(String),
}
impl From<LlamaCppError> for GenerateError {
fn from(e: LlamaCppError) -> Self {
GenerateError::Internal(e.to_string())
}
}
/// Construction parameters for [`LlamaCpp::new`].
#[derive(Debug, Clone)]
pub struct LlamaCppConfig {
/// Path of the model file handed to `load_model`.
pub model_path: std::path::PathBuf,
/// Optional SHA-256 handed to `load_model` alongside the path
/// (presumably checked against the file — confirm in the loader).
pub model_sha256: Option<[u8; 32]>,
/// Context size (`n_ctx`) of the generation context.
pub n_ctx: u32,
/// GPU layer count handed to `load_model`; reported in `AcceleratorInfo` clamped at zero.
pub n_gpu_layers: i32,
/// Seed given to the final `dist` sampler of every sampler chain.
pub seed: u32,
/// Optional multimodal projector file; when set an `Mtmd` context is created and the
/// backend reports v2 capabilities.
pub mmproj_path: Option<std::path::PathBuf>,
/// Optional SHA-256 passed to `verify_mmproj_sha256` before the projector is loaded.
pub mmproj_sha256: Option<[u8; 32]>,
/// When true a second context with embeddings enabled is created.
pub embed: bool,
/// Pooling type for the embedding context; `LLAMA_POOLING_TYPE_MEAN` when `None`.
pub embed_pooling: Option<i32>,
/// Context size (`n_ctx`) of the embedding context.
pub embed_n_ctx: u32,
}
impl Default for LlamaCppConfig {
    /// Baseline configuration: empty model path, CPU-only (`n_gpu_layers = 0`), an 8192-token
    /// generation context, no multimodal projector and embeddings disabled.
    fn default() -> Self {
        use std::path::PathBuf;

        Self {
            // Text model.
            model_path: PathBuf::new(),
            model_sha256: None,
            n_gpu_layers: 0,
            n_ctx: 8192,
            seed: 0xDEAD_BEEF,
            // Multimodal projector (off by default).
            mmproj_path: None,
            mmproj_sha256: None,
            // Embedding context (off by default).
            embed: false,
            embed_pooling: None,
            embed_n_ctx: 2048,
        }
    }
}
/// Owning wrapper around a non-null `llama_context` pointer; dropping it calls `llama_free`.
struct ContextHandle {
ptr: NonNull<ffi::llama_context>,
}
// NOTE(review): these impls assert that moving/sharing the raw context across threads is sound.
// In this file the handle is only reached through `Arc<Mutex<State>>`; confirm nothing else
// relies on unsynchronised shared access before depending on `Sync`.
unsafe impl Send for ContextHandle {}
unsafe impl Sync for ContextHandle {}
impl Drop for ContextHandle {
fn drop(&mut self) {
// The handle owns the context, so it is released exactly once here.
unsafe { ffi::llama_free(self.ptr.as_ptr()) };
}
}
/// llama.cpp-backed implementation of the `Backend` trait.
pub struct LlamaCpp {
// Value returned by `Backend::name`.
name: &'static str,
// Set to true by `new`, cleared by `stop`; checked at the start of every request.
ready: AtomicBool,
// Sampler seed copied from the config.
seed: u32,
// Accelerator description reported through `capabilities`.
accelerator: AcceleratorInfo,
// Label returned as `EmbedResult::model`.
model_label: String,
// All native handles; locked for the full duration of a generation or embedding run.
state: Arc<Mutex<State>>,
}
/// Maps the enabled cargo feature to an [`AcceleratorKind`].
///
/// When several features are enabled the first match wins, in the order
/// `cuda`, `metal`, `vulkan`, `rocm`; with none enabled the result is `Cpu`.
const fn compile_time_accelerator_kind() -> AcceleratorKind {
    let enabled = (
        cfg!(feature = "cuda"),
        cfg!(feature = "metal"),
        cfg!(feature = "vulkan"),
        cfg!(feature = "rocm"),
    );
    match enabled {
        (true, _, _, _) => AcceleratorKind::Cuda,
        (false, true, _, _) => AcceleratorKind::Metal,
        (false, false, true, _) => AcceleratorKind::Vulkan,
        (false, false, false, true) => AcceleratorKind::Rocm,
        (false, false, false, false) => AcceleratorKind::Cpu,
    }
}
/// Native resources shared by every request; always accessed through the mutex in `LlamaCpp`.
///
/// Field order is deliberate. Rust drops struct fields in declaration order, and the
/// multimodal projector and both contexts are all created from the model pointer. They are
/// therefore declared (and dropped) before `model`, so the model is released last.
/// (Assumes `ModelHandle` frees the underlying model when dropped — its `Drop` is not
/// visible in this file.)
struct State {
    // Multimodal projector; `None` when no mmproj was configured.
    mtmd: Option<Mtmd>,
    // Dedicated embedding context; `None` when embeddings are disabled.
    embed: Option<EmbedContext>,
    // Generation context.
    ctx: ContextHandle,
    // Capability snapshot taken when the projector was loaded.
    caps_v2: Option<BackendCapabilitiesV2>,
    // Loaded model; must outlive everything above.
    model: ModelHandle,
}
/// Context created with embeddings enabled, plus the model's embedding width.
struct EmbedContext {
// Context used only by `run_embed`.
ctx: ContextHandle,
// Value of `llama_n_embd` for the model; verified to be positive in `LlamaCpp::new`.
n_embd: u32,
}
/// Modalities reported by the multimodal projector at load time.
#[derive(Debug, Clone, Copy)]
struct BackendCapabilitiesV2 {
// Result of `Mtmd::supports_vision`.
vision: bool,
// Result of `Mtmd::supports_audio`.
audio: bool,
// Result of `Mtmd::audio_sample_rate`; stored but not read anywhere in this file.
#[allow(dead_code)]
audio_sample_rate: Option<u32>,
}
impl LlamaCpp {
/// Loads the model and creates every native resource the backend needs.
///
/// Steps, in order: one-time llama backend init, model load, generation context, optional
/// multimodal projector (when `mmproj_path` is set), optional embedding context (when
/// `embed` is true).
///
/// # Errors
/// Returns the loader's error when the model or projector cannot be loaded or verified,
/// `LlamaCppError::ContextInit` when a context cannot be created or the model reports a
/// non-positive embedding width, and `LlamaCppError::Mtmd` when the projector fails to
/// initialise.
pub fn new(config: LlamaCppConfig) -> Result<Self, LlamaCppError> {
ensure_backend_init();
let model = load_model(
&config.model_path,
config.model_sha256.as_ref(),
config.n_gpu_layers,
)?;
// Generation context: library defaults with only `n_ctx` overridden.
let ctx_ptr = unsafe {
let mut params = ffi::llama_context_default_params();
params.n_ctx = config.n_ctx;
ffi::llama_init_from_model(model.as_ptr(), params)
};
let ctx = NonNull::new(ctx_ptr)
.map(|ptr| ContextHandle { ptr })
.ok_or(LlamaCppError::ContextInit)?;
// The projector is optional; without it the backend exposes no v2 capabilities.
let (mtmd, caps_v2) = match config.mmproj_path.as_deref() {
Some(mmproj) => {
// The checksum is only verified when one was supplied.
if let Some(expected) = config.mmproj_sha256.as_ref() {
crate::llamacpp::loader::verify_mmproj_sha256(mmproj, expected)?;
}
let mtmd_ctx = unsafe { Mtmd::new(mmproj, model.as_ptr(), MtmdConfig::default())? };
let caps = BackendCapabilitiesV2 {
vision: mtmd_ctx.supports_vision(),
audio: mtmd_ctx.supports_audio(),
audio_sample_rate: mtmd_ctx.audio_sample_rate(),
};
(Some(mtmd_ctx), Some(caps))
}
None => (None, None),
};
let accelerator = AcceleratorInfo {
kind: compile_time_accelerator_kind(),
// Negative layer counts are reported as zero.
gpu_layers: config.n_gpu_layers.max(0) as u32,
};
let model_label = read_model_label(model.as_ptr(), &config.model_path);
let embed = if config.embed {
// Separate context so embedding requests never touch the generation context's memory.
// NOTE(review): batch-size parameters are left at library defaults here, while
// `run_embed` submits each input as a single batch — confirm long inputs are
// bounded upstream.
let embed_ctx_ptr = unsafe {
let mut params = ffi::llama_context_default_params();
params.n_ctx = config.embed_n_ctx;
params.embeddings = true;
params.pooling_type = config.embed_pooling.unwrap_or(ffi::LLAMA_POOLING_TYPE_MEAN);
ffi::llama_init_from_model(model.as_ptr(), params)
};
let embed_ctx = NonNull::new(embed_ctx_ptr)
.map(|ptr| ContextHandle { ptr })
.ok_or(LlamaCppError::ContextInit)?;
let n_embd = unsafe { ffi::llama_n_embd(model.as_ptr()) };
// A non-positive width would make the later slice construction meaningless.
if n_embd <= 0 {
return Err(LlamaCppError::ContextInit);
}
Some(EmbedContext {
ctx: embed_ctx,
n_embd: n_embd as u32,
})
} else {
None
};
Ok(Self {
name: "llamacpp",
ready: AtomicBool::new(true),
seed: config.seed,
accelerator,
model_label,
state: Arc::new(Mutex::new(State {
model,
ctx,
mtmd,
caps_v2,
embed,
})),
})
}
}
/// Picks a human-readable label for the model.
///
/// Preference order: the GGUF `general.name` metadata value, then the file stem of `path`
/// (when it is valid UTF-8), then the literal `"llamacpp"`.
fn read_model_label(model: *const ffi::llama_model, path: &std::path::Path) -> String {
    let from_file_stem = || {
        path.file_stem()
            .and_then(|stem| stem.to_str())
            .map(str::to_owned)
    };
    read_gguf_meta_string(model, "general.name")
        .or_else(from_file_stem)
        .unwrap_or_else(|| String::from("llamacpp"))
}
/// Reads one string-valued metadata entry from the model.
///
/// Returns `None` when the key contains a NUL byte, the lookup reports an error or an empty
/// value, or the value is not valid UTF-8. A 256-byte stack buffer is tried first; when the
/// reported length does not fit, the lookup is repeated with a heap buffer of `length + 1`.
fn read_gguf_meta_string(model: *const ffi::llama_model, key: &str) -> Option<String> {
    use std::ffi::CStr;
    use std::os::raw::c_char;

    let c_key = CString::new(key).ok()?;

    // First attempt with a fixed-size buffer.
    let mut stack_buf = [0i8; 256];
    let reported = unsafe {
        ffi::llama_model_meta_val_str(
            model,
            c_key.as_ptr(),
            stack_buf.as_mut_ptr() as *mut c_char,
            stack_buf.len(),
        )
    };
    // Negative means failure, zero means an empty value: neither yields a label.
    if reported <= 0 {
        return None;
    }
    let value_len = reported as usize;

    if value_len < stack_buf.len() {
        let text = unsafe { CStr::from_ptr(stack_buf.as_ptr() as *const _) };
        return text.to_str().map(str::to_owned).ok();
    }

    // The value was larger than the stack buffer: retry with room for the terminator.
    let mut heap_buf = vec![0i8; value_len + 1];
    let second = unsafe {
        ffi::llama_model_meta_val_str(
            model,
            c_key.as_ptr(),
            heap_buf.as_mut_ptr() as *mut c_char,
            heap_buf.len(),
        )
    };
    if second < 0 {
        return None;
    }
    let text = unsafe { CStr::from_ptr(heap_buf.as_ptr() as *const _) };
    text.to_str().map(str::to_owned).ok()
}
/// Runs `llama_backend_init` the first time it is called; later calls do nothing.
fn ensure_backend_init() {
    let init_once = || {
        unsafe { ffi::llama_backend_init() };
    };
    LLAMA_BACKEND_INIT.call_once(init_once);
}
#[async_trait]
impl Backend for LlamaCpp {
    /// Static backend name.
    fn name(&self) -> &str {
        self.name
    }

    /// True until `stop` has been called.
    fn ready(&self) -> bool {
        self.ready.load(Ordering::SeqCst)
    }

    /// Reports capabilities from the state captured at construction.
    ///
    /// With a projector loaded the backend advertises v2, tools and thinking, plus whichever
    /// of vision/audio the projector supports. Without one, only `embed` and `accelerator`
    /// differ from `BackendCapabilities::default()`.
    fn capabilities(&self) -> BackendCapabilities {
        // Copy what is needed and release the lock before building the result.
        let (caps_v2, has_embed) = {
            let state = self.state.lock().expect("poisoned llamacpp state mutex");
            (state.caps_v2, state.embed.is_some())
        };

        let Some(caps) = caps_v2 else {
            return BackendCapabilities {
                embed: has_embed,
                accelerator: self.accelerator,
                ..BackendCapabilities::default()
            };
        };

        BackendCapabilities {
            v2: true,
            vision: caps.vision,
            audio: caps.audio,
            video: false,
            tools: true,
            thinking: true,
            embed: has_embed,
            accelerator: self.accelerator,
        }
    }

    /// Renders the request, decodes its attachments and starts a blocking generation task.
    ///
    /// Rendering and attachment failures are returned as `InvalidRequest`. Failures inside
    /// the spawned task are only logged; the returned stream simply ends.
    async fn generate_v2(&self, req: ResolvedV2) -> Result<TokenStreamV2, GenerateError> {
        if !self.ready() {
            return Err(GenerateError::NotReady);
        }

        let renderer = Gemma4Renderer::new();
        let rendered = renderer
            .render(&req)
            .map_err(|e| GenerateError::InvalidRequest(format!("render: {e}")))?;

        // Stop at the first attachment that cannot be turned into a bitmap.
        let mut bitmaps: Vec<Bitmap> = Vec::new();
        for attachment in rendered.attachments.iter() {
            let bitmap = build_bitmap(attachment)
                .map_err(|e| GenerateError::InvalidRequest(format!("attachment: {e}")))?;
            bitmaps.push(bitmap);
        }

        let prompt = rendered.prompt;
        let max_new = req.max_tokens.unwrap_or(crate::DEFAULT_V2_MAX_TOKENS);
        let (tx, rx) = mpsc::channel(8);
        let state = Arc::clone(&self.state);
        let seed = self.seed;

        let worker = move || {
            if let Err(e) = run_generation_v2(&state, &prompt, &bitmaps, &req, max_new, seed, &tx)
            {
                warn!(error = %e, "v2 generation aborted mid-stream");
            }
        };
        tokio::task::spawn_blocking(worker);

        Ok(Box::pin(ReceiverStream::new(rx)))
    }

    /// Renders the chat messages to a prompt and starts a blocking generation task.
    ///
    /// Failures inside the spawned task are only logged; the returned stream simply ends.
    async fn generate(&self, req: Resolved) -> Result<TokenStream, GenerateError> {
        if !self.ready() {
            return Err(GenerateError::NotReady);
        }

        let Some(prompt) = render_chat_template(&self.state, &req.messages) else {
            return Err(GenerateError::InvalidRequest(
                "chat template render failed".into(),
            ));
        };

        let (tx, rx) = mpsc::channel(8);
        let state = Arc::clone(&self.state);
        let seed = self.seed;

        let worker = move || {
            if let Err(e) = run_generation(&state, &prompt, &req, seed, &tx) {
                warn!(error = %e, "generation aborted mid-stream");
            }
        };
        tokio::task::spawn_blocking(worker);

        Ok(Box::pin(ReceiverStream::new(rx)))
    }

    /// Prefixes every input according to the request's task and embeds them on a blocking
    /// thread. A panicked or cancelled task is reported as `EmbedError::Internal`.
    async fn embed(&self, req: EmbedResolved) -> Result<EmbedResult, EmbedError> {
        if !self.ready() {
            return Err(EmbedError::NotReady);
        }

        let prefixed: Vec<String> = req
            .input
            .iter()
            .map(|text| apply_task_prefix(req.task.as_ref(), text))
            .collect();
        let dimensions = req.dimensions;
        let label = self.model_label.clone();
        let state = Arc::clone(&self.state);

        let joined =
            tokio::task::spawn_blocking(move || run_embed(&state, &prefixed, dimensions, label))
                .await;
        match joined {
            Ok(result) => result,
            Err(e) => Err(EmbedError::Internal(format!("embed task join: {e}"))),
        }
    }

    /// Marks the backend as not ready. The timeout is ignored and running tasks are not
    /// waited for.
    async fn stop(&self, _timeout: Duration) -> Result<(), GenerateError> {
        self.ready.store(false, Ordering::SeqCst);
        Ok(())
    }
}
/// Renders chat messages into a turn-delimited prompt, returned as bytes.
///
/// Returns `None` for an empty message list. A system message is not emitted as its own
/// turn: it is held back and written, followed by a blank line, at the start of the next
/// user turn. A later system message replaces an earlier one that has not been consumed, and
/// one that is never followed by a user turn is not emitted. The prompt always ends with an
/// opened `model` turn.
fn render_chat_template(
    _state: &Arc<Mutex<State>>,
    messages: &[inferd_proto::Message],
) -> Option<Vec<u8>> {
    use inferd_proto::Role;

    if messages.is_empty() {
        return None;
    }

    let content_bytes: usize = messages.iter().map(|m| m.content.len()).sum();
    let mut rendered = String::with_capacity(content_bytes + 64 * messages.len() + 32);
    let mut held_system: Option<&str> = None;

    for message in messages {
        match message.role {
            Role::System => held_system = Some(message.content.as_str()),
            Role::User => {
                rendered += "<start_of_turn>user\n";
                if let Some(system_text) = held_system.take() {
                    rendered += system_text;
                    rendered += "\n\n";
                }
                rendered += &message.content;
                rendered += "<end_of_turn>\n";
            }
            Role::Assistant => {
                rendered += "<start_of_turn>model\n";
                rendered += &message.content;
                rendered += "<end_of_turn>\n";
            }
        }
    }

    // Leave the model's turn open so generation continues from here.
    rendered += "<start_of_turn>model\n";
    Some(rendered.into_bytes())
}
/// Runs one v1 text generation to completion, streaming events into `tx`.
///
/// Holds the state mutex for the whole run. The context memory is cleared, the prompt is
/// decoded, and up to `req.max_tokens` tokens are then sampled one at a time. Each sampled
/// token is sent as `TokenEvent::Token`; the run ends with `TokenEvent::Done` carrying
/// `StopReason::End` (end-of-generation token) or `StopReason::Length` (budget exhausted).
/// A dropped receiver stops the run silently with `Ok(())`.
///
/// # Errors
/// `Tokenize` when the prompt cannot be tokenised or yields no tokens, `Sampler` when the
/// sampler chain cannot be built, and `Decode(rc)` when `llama_decode` reports a non-zero
/// status. No `Done` event is sent on these paths.
///
/// NOTE(review): each token piece is converted with `from_utf8_lossy` on its own, so a
/// character whose bytes span several tokens is emitted as replacement characters.
fn run_generation(
    state: &Arc<Mutex<State>>,
    prompt: &[u8],
    req: &Resolved,
    seed: u32,
    tx: &mpsc::Sender<TokenEvent>,
) -> Result<(), LlamaCppError> {
    // Upper bound on tokens submitted per `llama_decode` call. A batch larger than the
    // context's `n_batch` is rejected by llama.cpp, so the prompt is fed in pieces; 512 is
    // the same batch size the v2 path passes to `eval_chunks`.
    const PROMPT_CHUNK_TOKENS: usize = 512;

    let guard = state.lock().expect("poisoned llamacpp state mutex");
    let model = guard.model.as_ptr();
    let ctx = guard.ctx.ptr.as_ptr();
    let vocab = unsafe { ffi::llama_model_get_vocab(model) };

    let mut tokens = tokenize(vocab, prompt, true, true)?;
    if tokens.is_empty() {
        // Nothing to decode means there would be no logits to sample from.
        return Err(LlamaCppError::Tokenize);
    }

    let sampler = build_sampler_chain(vocab, req, seed)?;
    let _sampler_guard = SamplerGuard { ptr: sampler };

    // Start from an empty context so earlier requests cannot leak into this one.
    unsafe {
        let mem = ffi::llama_get_memory(ctx);
        if !mem.is_null() {
            ffi::llama_memory_clear(mem, true);
        }
    }

    let prompt_len = tokens.len() as u32;
    for chunk in tokens.chunks_mut(PROMPT_CHUNK_TOKENS) {
        // `chunk.len()` is at most PROMPT_CHUNK_TOKENS, so the cast cannot truncate.
        let batch = unsafe { ffi::llama_batch_get_one(chunk.as_mut_ptr(), chunk.len() as i32) };
        let rc = unsafe { ffi::llama_decode(ctx, batch) };
        if rc != 0 {
            return Err(LlamaCppError::Decode(rc));
        }
    }

    let mut completion_tokens: u32 = 0;
    let max_new = req.max_tokens;
    let mut buf = [0u8; 256];
    for _ in 0..max_new {
        // `llama_sampler_sample` samples AND accepts the token, so no separate
        // `llama_sampler_accept` call is made: a second accept would advance stateful
        // samplers (notably the grammar sampler) twice for a single token.
        let next: ffi::llama_token = unsafe { ffi::llama_sampler_sample(sampler, ctx, -1) };
        let is_eog = unsafe { ffi::llama_vocab_is_eog(vocab, next) };
        if is_eog {
            let _ = tx.blocking_send(TokenEvent::Done {
                stop_reason: StopReason::End,
                usage: Usage {
                    prompt_tokens: prompt_len,
                    completion_tokens,
                },
            });
            return Ok(());
        }

        let piece = token_to_piece(vocab, next, &mut buf);
        let text = String::from_utf8_lossy(piece).into_owned();
        if tx.blocking_send(TokenEvent::Token(text)).is_err() {
            debug!("generation cancelled (receiver dropped)");
            return Ok(());
        }
        completion_tokens = completion_tokens.saturating_add(1);

        // Feed the sampled token back so the next iteration has fresh logits.
        let mut next_arr = [next];
        let batch = unsafe { ffi::llama_batch_get_one(next_arr.as_mut_ptr(), 1) };
        let rc = unsafe { ffi::llama_decode(ctx, batch) };
        if rc != 0 {
            return Err(LlamaCppError::Decode(rc));
        }
    }

    let _ = tx.blocking_send(TokenEvent::Done {
        stop_reason: StopReason::Length,
        usage: Usage {
            prompt_tokens: prompt_len,
            completion_tokens,
        },
    });
    Ok(())
}
/// Frees a sampler chain when it goes out of scope, so every exit path of a generation run
/// releases it.
struct SamplerGuard {
// May be null, in which case dropping the guard does nothing.
ptr: *mut ffi::llama_sampler,
}
// NOTE(review): asserts the raw sampler pointer may be moved to another thread; in this file
// guards are created and dropped inside a single generation function.
unsafe impl Send for SamplerGuard {}
impl Drop for SamplerGuard {
fn drop(&mut self) {
if !self.ptr.is_null() {
unsafe { ffi::llama_sampler_free(self.ptr) };
}
}
}
/// Builds the v1 sampler chain: optional grammar, then top-k, top-p, temperature and a
/// seeded `dist` sampler.
///
/// The caller owns the returned chain and must free it (see `SamplerGuard`).
///
/// # Errors
/// Returns `LlamaCppError::Sampler` when the grammar exceeds the limits enforced by
/// `validate_grammar`, contains an interior NUL byte, is rejected by
/// `llama_sampler_init_grammar`, or when the chain itself cannot be allocated.
fn build_sampler_chain(
    vocab: *const ffi::llama_vocab,
    req: &Resolved,
    seed: u32,
) -> Result<*mut ffi::llama_sampler, LlamaCppError> {
    // Validate and convert the grammar BEFORE allocating the native chain, so these early
    // failures (including an interior NUL in the grammar text) cannot leak the chain.
    let grammar_c = if req.grammar.is_empty() {
        None
    } else {
        validate_grammar(&req.grammar)?;
        let converted =
            CString::new(req.grammar.as_bytes()).map_err(|_| LlamaCppError::Sampler)?;
        Some(converted)
    };

    let chain = unsafe {
        let params = ffi::llama_sampler_chain_default_params();
        ffi::llama_sampler_chain_init(params)
    };
    if chain.is_null() {
        return Err(LlamaCppError::Sampler);
    }

    if let Some(grammar_c) = grammar_c.as_deref() {
        let grammar_sampler = unsafe {
            ffi::llama_sampler_init_grammar(vocab, grammar_c.as_ptr(), c"root".as_ptr())
        };
        if grammar_sampler.is_null() {
            // The chain has not been handed to the caller yet, so it is released here.
            unsafe { ffi::llama_sampler_free(chain) };
            return Err(LlamaCppError::Sampler);
        }
        // The grammar goes first so later stages only see candidates it allows.
        unsafe { ffi::llama_sampler_chain_add(chain, grammar_sampler) };
    }

    unsafe {
        ffi::llama_sampler_chain_add(chain, ffi::llama_sampler_init_top_k(req.top_k as i32));
        ffi::llama_sampler_chain_add(chain, ffi::llama_sampler_init_top_p(req.top_p as f32, 1));
        ffi::llama_sampler_chain_add(chain, ffi::llama_sampler_init_temp(req.temperature as f32));
        ffi::llama_sampler_chain_add(chain, ffi::llama_sampler_init_dist(seed));
    }
    Ok(chain)
}
/// Largest grammar text accepted, in bytes.
pub const MAX_GRAMMAR_BYTES: usize = 64 * 1024;
/// Largest number of `|` characters accepted in a grammar.
pub const MAX_GRAMMAR_ALTERNATIONS: usize = 4096;
/// Cheap size checks applied to a grammar before it is handed to the native parser.
///
/// Returns `LlamaCppError::Sampler` when the text is longer than [`MAX_GRAMMAR_BYTES`] or
/// contains more than [`MAX_GRAMMAR_ALTERNATIONS`] `|` characters. Every `|` is counted,
/// including any inside string literals.
fn validate_grammar(grammar: &str) -> Result<(), LlamaCppError> {
    let too_long = grammar.len() > MAX_GRAMMAR_BYTES;
    // Only scan the text when the length check has already passed.
    let too_many_alternations = || grammar.matches('|').count() > MAX_GRAMMAR_ALTERNATIONS;

    if too_long || too_many_alternations() {
        Err(LlamaCppError::Sampler)
    } else {
        Ok(())
    }
}
/// Tokenises `text` with the model vocabulary.
///
/// Uses the usual two-call pattern: a probe with a zero-capacity buffer to learn the token
/// count, then a second call into a buffer of exactly that size.
///
/// # Errors
/// Returns `LlamaCppError::Tokenize` when the text is too long for the C API's 32-bit
/// length, when the library reports its overflow sentinel, or when the second call fails.
fn tokenize(
    vocab: *const ffi::llama_vocab,
    text: &[u8],
    add_special: bool,
    parse_special: bool,
) -> Result<Vec<ffi::llama_token>, LlamaCppError> {
    let text_ptr = text.as_ptr() as *const std::os::raw::c_char;
    // The C API takes an `int32_t` length; refuse rather than silently truncate.
    let text_len = i32::try_from(text.len()).map_err(|_| LlamaCppError::Tokenize)?;

    let probe = unsafe {
        ffi::llama_tokenize(
            vocab,
            text_ptr,
            text_len,
            ptr::null_mut(),
            0,
            add_special,
            parse_special,
        )
    };
    if probe >= 0 {
        // A zero-capacity buffer was sufficient, i.e. the text produced no tokens. Nothing
        // was written, so return an empty list rather than placeholder token ids.
        return Ok(Vec::new());
    }

    // A negative result is the negated token count. `i32::MIN` is the library's overflow
    // sentinel and cannot be negated, so it is reported as a failure.
    let capacity = probe.checked_neg().ok_or(LlamaCppError::Tokenize)?;
    let mut tokens = vec![0i32; capacity as usize];
    let written = unsafe {
        ffi::llama_tokenize(
            vocab,
            text_ptr,
            text_len,
            tokens.as_mut_ptr(),
            capacity,
            add_special,
            parse_special,
        )
    };
    if written < 0 {
        return Err(LlamaCppError::Tokenize);
    }
    tokens.truncate(written as usize);
    Ok(tokens)
}
/// Writes the text of `token` into `buf` and returns the written prefix.
///
/// Returns an empty slice when the library reports zero or a negative length; a token whose
/// piece does not fit in `buf` is therefore dropped rather than truncated.
fn token_to_piece(
    vocab: *const ffi::llama_vocab,
    token: ffi::llama_token,
    buf: &mut [u8],
) -> &[u8] {
    let capacity = buf.len();
    let written = unsafe {
        ffi::llama_token_to_piece(
            vocab,
            token,
            buf.as_mut_ptr() as *mut std::os::raw::c_char,
            capacity as i32,
            0,
            true,
        )
    };
    match usize::try_from(written) {
        // Clamp defensively so the slice can never exceed the buffer.
        Ok(len) if len > 0 => &buf[..len.min(capacity)],
        _ => &[],
    }
}
/// Converts one resolved attachment into a bitmap for the multimodal tokenizer.
///
/// * Image: the base64 payload is decoded and passed with `width`/`height` to
///   `Bitmap::from_image_rgb`.
/// * Audio: the base64 payload is decoded as little-endian `f32` samples; its length must be
///   a multiple of four bytes.
/// * Video and unknown attachments are rejected.
///
/// # Errors
/// `Base64` (carrying the attachment id) when decoding fails, `Render` for a misaligned
/// audio payload or an unsupported kind, and whatever the `Bitmap` constructors return.
fn build_bitmap(att: &Attachment) -> Result<Bitmap, LlamaCppError> {
    use base64::engine::general_purpose::STANDARD;

    let bitmap = match att {
        Attachment::Image {
            id,
            width,
            height,
            bytes,
        } => {
            let pixels = STANDARD
                .decode(bytes)
                .map_err(|_| LlamaCppError::Base64(id.clone()))?;
            Bitmap::from_image_rgb(*width, *height, &pixels)?
        }
        Attachment::Audio { id, bytes, .. } => {
            let pcm = STANDARD
                .decode(bytes)
                .map_err(|_| LlamaCppError::Base64(id.clone()))?;
            if pcm.len() % 4 != 0 {
                return Err(LlamaCppError::Render(format!(
                    "audio attachment {id:?}: byte length not a multiple of 4"
                )));
            }
            // Four bytes per sample, little-endian.
            let samples: Vec<f32> = pcm
                .chunks_exact(4)
                .map(|quad| f32::from_le_bytes([quad[0], quad[1], quad[2], quad[3]]))
                .collect();
            Bitmap::from_audio_f32(&samples)?
        }
        Attachment::Video { id, .. } => {
            return Err(LlamaCppError::Render(format!(
                "video attachment {id:?} not supported by the llamacpp adapter"
            )));
        }
        Attachment::Unknown => {
            return Err(LlamaCppError::Render(
                "unknown attachment kind in resolved request".into(),
            ));
        }
    };
    Ok(bitmap)
}
/// Builds the v2 sampler chain: top-k, top-p, temperature, then a seeded `dist` sampler.
///
/// Unset request fields fall back to `temperature = 1.0`, `top_p = 0.95` and `top_k = 64`.
/// The vocabulary argument is unused. The caller owns the returned chain.
///
/// # Errors
/// Returns `LlamaCppError::Sampler` when the chain cannot be allocated.
fn build_sampler_chain_v2(
    _vocab: *const ffi::llama_vocab,
    req: &ResolvedV2,
    seed: u32,
) -> Result<*mut ffi::llama_sampler, LlamaCppError> {
    let top_k = req.top_k.unwrap_or(64) as i32;
    let top_p = req.top_p.unwrap_or(0.95) as f32;
    let temperature = req.temperature.unwrap_or(1.0) as f32;

    let chain =
        unsafe { ffi::llama_sampler_chain_init(ffi::llama_sampler_chain_default_params()) };
    if chain.is_null() {
        return Err(LlamaCppError::Sampler);
    }

    // Appends one stage to the chain.
    let append = |stage: *mut ffi::llama_sampler| {
        unsafe { ffi::llama_sampler_chain_add(chain, stage) };
    };
    append(unsafe { ffi::llama_sampler_init_top_k(top_k) });
    append(unsafe { ffi::llama_sampler_init_top_p(top_p, 1) });
    append(unsafe { ffi::llama_sampler_init_temp(temperature) });
    append(unsafe { ffi::llama_sampler_init_dist(seed) });

    Ok(chain)
}
/// Runs one v2 (multimodal, tool-aware) generation to completion, streaming into `tx`.
///
/// Holds the state mutex for the whole run. The context memory is cleared, the prompt and
/// bitmaps are tokenised and evaluated through the multimodal wrapper, and up to `max_new`
/// tokens are then sampled. Sampled text is fed through a `ToolCallParser`, whose outputs
/// become text, thinking or tool-use events. The run ends with `TokenEventV2::Done`:
/// `ToolUse` or `EndTurn` on an end-of-generation token (depending on whether a tool call
/// was emitted), `MaxTokens` when the budget runs out. A dropped receiver stops the run
/// silently with `Ok(())`.
///
/// # Errors
/// `NoMmproj` when no projector is loaded, `Mtmd` when tokenising or evaluating the prompt
/// fails, `Sampler` when the chain cannot be built, `Render` when the parser reports
/// malformed tool-call output, and `Decode(rc)` when `llama_decode` fails. No `Done` event
/// is sent on these paths.
///
/// NOTE(review): each token piece is converted with `from_utf8_lossy` on its own, so a
/// character whose bytes span several tokens reaches the parser as replacement characters.
fn run_generation_v2(
    state: &Arc<Mutex<State>>,
    prompt: &str,
    bitmaps: &[Bitmap],
    req: &ResolvedV2,
    max_new: u32,
    seed: u32,
    tx: &mpsc::Sender<TokenEventV2>,
) -> Result<(), LlamaCppError> {
    let guard = state.lock().expect("poisoned llamacpp state mutex");
    let model = guard.model.as_ptr();
    let ctx = guard.ctx.ptr.as_ptr();
    let mtmd = guard.mtmd.as_ref().ok_or(LlamaCppError::NoMmproj)?;
    let vocab = unsafe { ffi::llama_model_get_vocab(model) };

    // Start from an empty context so earlier requests cannot leak into this one.
    unsafe {
        let mem = ffi::llama_get_memory(ctx);
        if !mem.is_null() {
            ffi::llama_memory_clear(mem, true);
        }
    }

    let bitmap_refs: Vec<&Bitmap> = bitmaps.iter().collect();
    let chunks = mtmd
        .tokenize(prompt, &bitmap_refs)
        .map_err(LlamaCppError::Mtmd)?;
    // The returned position is not needed: `llama_batch_get_one` batches below carry no
    // explicit positions, so nothing in this function ever read the counter.
    unsafe { mtmd.eval_chunks(ctx, &chunks, 0, 0, 512, true) }.map_err(LlamaCppError::Mtmd)?;
    let prompt_tokens = unsafe { crate::mtmd_ffi::mtmd_helper_get_n_tokens(chunks.raw()) } as u32;
    drop(chunks);

    let sampler = build_sampler_chain_v2(vocab, req, seed)?;
    let _sampler_guard = SamplerGuard { ptr: sampler };

    let mut completion_tokens: u32 = 0;
    let mut buf = [0u8; 256];
    let mut parser = ToolCallParser::new();
    let mut emitted_tool_use = false;

    for _ in 0..max_new {
        // `llama_sampler_sample` samples AND accepts the token, so no separate
        // `llama_sampler_accept` call is made (it would register every token twice).
        let next: ffi::llama_token = unsafe { ffi::llama_sampler_sample(sampler, ctx, -1) };
        let is_eog = unsafe { ffi::llama_vocab_is_eog(vocab, next) };
        if is_eog {
            // Flush whatever the parser is still buffering before reporting completion.
            for ev in parser.finish() {
                if let Some(out_ev) = parser_output_to_event_v2(ev, &mut emitted_tool_use)
                    && tx.blocking_send(out_ev).is_err()
                {
                    return Ok(());
                }
            }
            let stop = if emitted_tool_use {
                StopReasonV2::ToolUse
            } else {
                StopReasonV2::EndTurn
            };
            let _ = tx.blocking_send(TokenEventV2::Done {
                stop_reason: stop,
                usage: UsageV2 {
                    input_tokens: prompt_tokens,
                    output_tokens: completion_tokens,
                },
            });
            return Ok(());
        }

        let piece = token_to_piece(vocab, next, &mut buf);
        let text = String::from_utf8_lossy(piece).into_owned();
        for ev in parser.push(&text) {
            if let TokenOutput::Malformed(reason) = &ev {
                warn!(reason = %reason, "tool-call parse failed; aborting generation");
                return Err(LlamaCppError::Render(reason.clone()));
            }
            if let Some(out_ev) = parser_output_to_event_v2(ev, &mut emitted_tool_use)
                && tx.blocking_send(out_ev).is_err()
            {
                debug!("v2 generation cancelled (receiver dropped)");
                return Ok(());
            }
        }
        completion_tokens = completion_tokens.saturating_add(1);

        // Feed the sampled token back so the next iteration has fresh logits.
        let mut next_arr = [next];
        let batch = unsafe { ffi::llama_batch_get_one(next_arr.as_mut_ptr(), 1) };
        let rc = unsafe { ffi::llama_decode(ctx, batch) };
        if rc != 0 {
            return Err(LlamaCppError::Decode(rc));
        }
    }

    // Token budget exhausted: flush the parser, then report `MaxTokens`.
    for ev in parser.finish() {
        if let Some(out_ev) = parser_output_to_event_v2(ev, &mut emitted_tool_use)
            && tx.blocking_send(out_ev).is_err()
        {
            return Ok(());
        }
    }
    let _ = tx.blocking_send(TokenEventV2::Done {
        stop_reason: StopReasonV2::MaxTokens,
        usage: UsageV2 {
            input_tokens: prompt_tokens,
            output_tokens: completion_tokens,
        },
    });
    Ok(())
}
/// Translates one parser output into a stream event.
///
/// Empty text or thinking fragments and malformed outputs produce `None`. A tool call sets
/// `*emitted_tool_use` to true as a side effect.
fn parser_output_to_event_v2(ev: TokenOutput, emitted_tool_use: &mut bool) -> Option<TokenEventV2> {
    match ev {
        TokenOutput::Text(text) => (!text.is_empty()).then(|| TokenEventV2::Text(text)),
        TokenOutput::Thinking(text) => (!text.is_empty()).then(|| TokenEventV2::Thinking(text)),
        TokenOutput::ToolUse {
            tool_call_id,
            name,
            input,
        } => {
            *emitted_tool_use = true;
            let event = TokenEventV2::ToolUse {
                tool_call_id,
                name,
                input,
            };
            Some(event)
        }
        TokenOutput::Malformed(_) => None,
    }
}
/// Prepends the task-specific prompt prefix to `input`.
///
/// With no task, or `EmbedTask::Other`, the input is returned unchanged.
fn apply_task_prefix(task: Option<&inferd_proto::embed::EmbedTask>, input: &str) -> String {
    use inferd_proto::embed::EmbedTask;

    let prefix = match task {
        None | Some(EmbedTask::Other) => "",
        Some(EmbedTask::RetrievalQuery) => "task: search result | query: ",
        Some(EmbedTask::RetrievalDocument) => "title: none | text: ",
        Some(EmbedTask::Similarity) => "task: sentence similarity | query: ",
        Some(EmbedTask::Classification) => "task: classification | query: ",
        Some(EmbedTask::Clustering) => "task: clustering | query: ",
        Some(EmbedTask::QuestionAnswering) => "task: question answering | query: ",
        Some(EmbedTask::FactVerification) => "task: fact checking | query: ",
        Some(EmbedTask::CodeRetrievalQuery) => "task: code retrieval | query: ",
    };
    [prefix, input].concat()
}
/// Embeds every input string on the dedicated embedding context.
///
/// Holds the state mutex for the whole call. Each input is tokenised, encoded as a single
/// batch, read back through `llama_get_embeddings_seq`, truncated to the requested width
/// and L2-normalised.
///
/// # Errors
/// `Unsupported` when no embedding context exists, `InvalidRequest` when `requested_dim`
/// exceeds the model width, tokenising fails or an input yields no tokens, and
/// `Unavailable` when `llama_encode` fails or no embedding is returned. The first failing
/// input aborts the whole request.
fn run_embed(
state: &Arc<Mutex<State>>,
inputs: &[String],
requested_dim: Option<u32>,
model_label: String,
) -> Result<EmbedResult, EmbedError> {
let guard = state.lock().expect("poisoned llamacpp state mutex");
let model = guard.model.as_ptr();
let embed = guard.embed.as_ref().ok_or(EmbedError::Unsupported)?;
let ctx = embed.ctx.ptr.as_ptr();
let n_embd = embed.n_embd as usize;
// A requested width can only shrink the vector, never grow it.
if let Some(d) = requested_dim
&& d as usize > n_embd
{
return Err(EmbedError::InvalidRequest(format!(
"dimensions {d} exceeds model n_embd {n_embd}"
)));
}
// NOTE(review): `Some(0)` is not rejected here and would yield empty vectors — confirm
// the request is validated upstream.
let out_dim = requested_dim.map(|d| d as usize).unwrap_or(n_embd);
let vocab = unsafe { ffi::llama_model_get_vocab(model) };
let mut input_tokens: u32 = 0;
let mut embeddings: Vec<Vec<f32>> = Vec::with_capacity(inputs.len());
for text in inputs {
// Each input is embedded independently, so the context memory is cleared first.
unsafe {
let mem = ffi::llama_get_memory(ctx);
if !mem.is_null() {
ffi::llama_memory_clear(mem, true);
}
}
// Special tokens are added but special-token text inside the input is not parsed.
let mut tokens = tokenize(vocab, text.as_bytes(), true, false)
.map_err(|_| EmbedError::InvalidRequest("tokenize failed".into()))?;
if tokens.is_empty() {
return Err(EmbedError::InvalidRequest(
"input produced zero tokens".into(),
));
}
input_tokens = input_tokens.saturating_add(tokens.len() as u32);
let batch = unsafe { ffi::llama_batch_get_one(tokens.as_mut_ptr(), tokens.len() as i32) };
let rc = unsafe { ffi::llama_encode(ctx, batch) };
if rc != 0 {
return Err(EmbedError::Unavailable(format!(
"llama_encode failed: {rc}"
)));
}
// Embedding for sequence 0.
let raw = unsafe { ffi::llama_get_embeddings_seq(ctx, 0) };
if raw.is_null() {
return Err(EmbedError::Unavailable(
"llama_get_embeddings_seq returned null".into(),
));
}
// The slice is built with the stored model width `n_embd`.
let slice = unsafe { std::slice::from_raw_parts(raw, n_embd) };
// Truncate first, then normalise, so the returned vector has unit length.
let mut vec: Vec<f32> = slice[..out_dim].to_vec();
l2_normalise(&mut vec);
embeddings.push(vec);
}
Ok(EmbedResult {
embeddings,
dimensions: out_dim as u32,
model: model_label,
usage: EmbedUsage { input_tokens },
})
}
/// Scales `v` in place to unit Euclidean length.
///
/// Vectors whose norm is not strictly positive (empty, all zeros, or NaN) are left untouched.
fn l2_normalise(v: &mut [f32]) {
    let norm = v.iter().map(|&component| component * component).sum::<f32>().sqrt();
    // Written as a negated `>` so that a NaN norm also takes the early return.
    if !(norm > 0.0) {
        return;
    }
    v.iter_mut().for_each(|component| *component /= norm);
}
// Unit tests for the size limits enforced by `validate_grammar`.
#[cfg(test)]
mod grammar_tests {
use super::*;
// A two-alternative grammar is well under both limits.
#[test]
fn small_grammar_is_accepted() {
let g = r#"root ::= "yes" | "no""#;
validate_grammar(g).unwrap();
}
// A small JSON-like grammar must not trip either limit.
#[test]
fn realistic_json_grammar_is_accepted() {
let g = r#"
root ::= object
object ::= "{" ws members? ws "}"
members ::= pair ("," ws pair)*
pair ::= string ws ":" ws value
value ::= object | string | number | "true" | "false" | "null"
string ::= "\"" [^"]* "\""
number ::= [0-9]+ ("." [0-9]+)?
ws ::= [ \t\n]*
"#;
validate_grammar(g).unwrap();
}
// One byte over the byte limit is rejected.
#[test]
fn oversized_grammar_is_rejected() {
let g = "x".repeat(MAX_GRAMMAR_BYTES + 1);
assert!(validate_grammar(&g).is_err());
}
// One `|` over the alternation limit is rejected.
#[test]
fn excessive_alternations_rejected() {
let g = "|".repeat(MAX_GRAMMAR_ALTERNATIONS + 1);
assert!(validate_grammar(&g).is_err());
}
// Exactly the alternation limit is still accepted (the bound is inclusive).
#[test]
fn alternation_count_under_threshold_accepted() {
let g = "|".repeat(MAX_GRAMMAR_ALTERNATIONS);
validate_grammar(&g).unwrap();
}
}