mod sys;
pub use sys::{llama_log_get_verbosity, llama_log_set_verbosity};
use crate::runtime_adapter::llm::{
ChatMessage, GenerationConfig, GenerationOutput, LlmBackend, LlmConfig, LlmResult,
};
#[cfg(feature = "llm-llamacpp")]
use crate::runtime_adapter::llm_telemetry::StreamingTelemetry;
#[cfg(feature = "llm-llamacpp")]
use crate::runtime_adapter::streaming_postprocess::{
merge_stop_patterns, strip_thinking_tags, trim_partial_stop_suffix, truncate_at_first_stop,
StreamingTextFilter, CHAT_STOP_PATTERNS, CHAT_STOP_PATTERNS_BROKEN,
};
use crate::runtime_adapter::AdapterError;
use crate::tracing as xybrid_trace;
use std::sync::Mutex;
#[cfg(feature = "llm-llamacpp")]
use std::sync::Once;
#[cfg(feature = "llm-llamacpp")]
/// Guards the one-time, process-wide `sys::llama_backend_init()` call (and the optional
/// verbosity override) performed by `LlamaCppBackend::new`.
static BACKEND_INIT: Once = Once::new();
/// In-process llama.cpp backend that runs GGUF models.
///
/// Generation goes through `&self`; the context lives behind a `Mutex` that is held for the
/// whole of each generation call, so concurrent calls are serialised. `kv_state` records which
/// prompt tokens are believed to sit in the native KV cache so that a following prompt sharing
/// a prefix only has to decode its tail.
#[cfg(feature = "llm-llamacpp")]
pub struct LlamaCppBackend {
    /// Inference context created from `model` by `load()`; `None` until then.
    /// Declared before `model` so that, even without the explicit `Drop`, it is released first.
    context: Mutex<Option<sys::LlamaContext>>,
    /// Model loaded by `load()`; `None` before `load()` and after `unload()`.
    model: Option<sys::LlamaModel>,
    /// Configuration passed to the last successful `load()`.
    config: Option<LlmConfig>,
    /// Prefix-reuse bookkeeping for the KV cache.
    kv_state: Mutex<KvCacheState>,
}
#[cfg(feature = "llm-llamacpp")]
/// Bookkeeping for reusing the llama.cpp KV cache across consecutive prompts.
///
/// The default value (no tokens, no hit) means "nothing is known to be cached".
#[derive(Default)]
struct KvCacheState {
    /// Prefix length reused by the most recent prompt (`Some(0)` when the cache was cleared);
    /// `None` when no prompt has been prepared since the last reset.
    last_prefix_hit: Option<usize>,
    /// Prompt tokens recorded for the most recent generation.
    cached_tokens: Vec<i32>,
}
#[cfg(feature = "llm-llamacpp")]
impl LlamaCppBackend {
    /// Creates a backend with nothing loaded; call `load()` before generating.
    ///
    /// The first call in the process runs `sys::llama_backend_init()`. If the environment
    /// variable `XYBRID_LLAMACPP_VERBOSITY` parses as an `i32`, it is then applied as the
    /// llama.cpp log verbosity; unparsable values are ignored.
    ///
    /// # Errors
    /// Never fails in this build; the `Result` mirrors the feature-disabled stub.
    pub fn new() -> LlmResult<Self> {
        BACKEND_INIT.call_once(|| {
            sys::llama_backend_init();
            let requested = std::env::var("XYBRID_LLAMACPP_VERBOSITY")
                .ok()
                .and_then(|raw| raw.parse::<i32>().ok());
            if let Some(verbosity) = requested {
                sys::llama_log_set_verbosity(verbosity);
            }
        });

        let backend = Self {
            model: None,
            context: Mutex::new(None),
            config: None,
            kv_state: Mutex::new(KvCacheState::default()),
        };
        Ok(backend)
    }
}
#[cfg(feature = "llm-llamacpp")]
impl Drop for LlamaCppBackend {
    /// Releases the native context before the model it was created from.
    ///
    /// A poisoned context mutex (left behind by a panic during generation) is recovered
    /// rather than unwrapped: panicking inside `drop` while already unwinding aborts the
    /// process, and the data is being discarded anyway.
    fn drop(&mut self) {
        let context = self
            .context
            .get_mut()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        let _ = context.take();
        let _ = self.model.take();
    }
}
#[cfg(feature = "llm-llamacpp")]
impl Default for LlamaCppBackend {
    /// Same as [`LlamaCppBackend::new`].
    ///
    /// # Panics
    /// Panics if `new()` returns an error.
    fn default() -> Self {
        match Self::new() {
            Ok(backend) => backend,
            Err(err) => panic!("Failed to create LlamaCppBackend: {:?}", err),
        }
    }
}
#[cfg(feature = "llm-llamacpp")]
impl LlamaCppBackend {
    /// Runs `f` with the loaded model and context, holding the context lock for its duration.
    ///
    /// # Errors
    /// - `ModelNotLoaded` when no model has been loaded (checked first) or no context exists.
    /// - `RuntimeError` when the context mutex is poisoned.
    /// - Whatever `f` returns.
    fn with_model_and_context<R, F>(&self, f: F) -> LlmResult<R>
    where
        F: FnOnce(&sys::LlamaModel, &sys::LlamaContext) -> LlmResult<R>,
    {
        let model = match &self.model {
            Some(model) => model,
            None => {
                return Err(AdapterError::ModelNotLoaded(
                    "No model loaded. Call load() first.".to_string(),
                ))
            }
        };
        let guard = match self.context.lock() {
            Ok(guard) => guard,
            Err(_) => {
                return Err(AdapterError::RuntimeError(
                    "Context mutex poisoned".to_string(),
                ))
            }
        };
        let context = match guard.as_ref() {
            Some(context) => context,
            None => {
                return Err(AdapterError::ModelNotLoaded(
                    "No context. Call load() first.".to_string(),
                ))
            }
        };
        f(model, context)
    }

    /// Prepares the native KV cache for `new_tokens` and returns `(tokens_to_decode, n_past)`.
    ///
    /// When a shared prefix with the previously recorded prompt can be kept, everything after
    /// it is removed from sequence 0 and only the tail is returned, with `n_past` equal to the
    /// kept length. The cache is cleared instead (and the whole prompt returned with
    /// `n_past == 0`) when:
    /// - the model has recurrent state,
    /// - no prefix is shared, or
    /// - prompt length plus `max_new_tokens` reaches the context size.
    ///
    /// In every case `new_tokens` becomes the recorded prompt.
    ///
    /// # Errors
    /// `RuntimeError` if the KV-state mutex is poisoned.
    fn prepare_kv_cache_and_get_tail(
        &self,
        model: &sys::LlamaModel,
        context: &sys::LlamaContext,
        new_tokens: &[i32],
        max_new_tokens: usize,
    ) -> LlmResult<(Vec<i32>, usize)> {
        let mut kv = self
            .kv_state
            .lock()
            .map_err(|_| AdapterError::RuntimeError("KV state mutex poisoned".to_string()))?;

        let reuse = if sys::llama_model_has_recurrent_state(model) {
            0
        } else {
            let window = sys::llama_n_ctx(context);
            let shared = compute_reusable_prefix_len(&kv.cached_tokens, new_tokens);
            // The kept prefix plus the undecoded tail is exactly the whole prompt, so the
            // overflow check is prompt length plus generation budget against the window.
            let needed = new_tokens.len().saturating_add(max_new_tokens);
            if needed >= window {
                0
            } else {
                shared
            }
        };

        if reuse == 0 {
            sys::llama_kv_cache_clear(context);
        } else {
            sys::llama_kv_cache_seq_rm(context, 0, reuse);
        }
        kv.cached_tokens = new_tokens.to_vec();
        kv.last_prefix_hit = Some(reuse);
        Ok((new_tokens[reuse..].to_vec(), reuse))
    }

    /// Clears the native KV cache and the Rust-side prefix record after a failed generation,
    /// so the next prompt does not trust a cache whose contents are unknown.
    fn reset_kv_cache_after_failed_stream(&self, context: &sys::LlamaContext) {
        sys::llama_kv_cache_clear(context);
        self.clear_cached_prefix_state();
    }

    /// Resets the prefix record to its default (empty) state.
    ///
    /// Best effort: if the KV-state mutex is poisoned nothing is changed.
    fn clear_cached_prefix_state(&self) {
        if let Ok(mut kv) = self.kv_state.lock() {
            *kv = KvCacheState::default();
        }
    }
}
#[cfg(feature = "llm-llamacpp")]
/// Returns how many leading tokens of `new_tokens` match `cached` and may stay in the KV cache.
///
/// The result is capped at `new_tokens.len() - 1` (0 for an empty prompt). At least one prompt
/// token is therefore always left to decode, presumably so the decoder produces fresh logits
/// to sample from.
fn compute_reusable_prefix_len(cached: &[i32], new_tokens: &[i32]) -> usize {
    let limit = new_tokens.len().saturating_sub(1).min(cached.len());
    (0..limit)
        .position(|i| cached[i] != new_tokens[i])
        .unwrap_or(limit)
}
#[cfg(feature = "llm-llamacpp")]
impl LlmBackend for LlamaCppBackend {
    fn name(&self) -> &str {
        "llama-cpp"
    }

    fn wire_label(&self) -> Option<&'static str> {
        Some("llamacpp")
    }

    fn supported_formats(&self) -> Vec<&'static str> {
        vec!["gguf"]
    }

    /// Loads a GGUF model and creates its inference context.
    ///
    /// `config.model_path` may name a `.gguf` file or a directory. For a directory, the
    /// `.gguf` entries are sorted and the first one is used. The choice is therefore
    /// deterministic instead of depending on `read_dir` order; for conventionally numbered
    /// split files this is the first shard.
    ///
    /// A previously loaded model is replaced only after the new model and context were created.
    /// The old context is released before the old model it was built from. The prefix-cache
    /// record is reset because the new context starts with an empty KV cache.
    ///
    /// # Errors
    /// - `ModelNotFound` if the path does not exist or the directory holds no `.gguf` file.
    /// - `IOError` if the directory cannot be read.
    /// - `RuntimeError` if llama.cpp cannot load the model or create the context.
    fn load(&mut self, config: &LlmConfig) -> LlmResult<()> {
        use std::path::Path;

        let model_path = Path::new(&config.model_path);
        if !model_path.exists() {
            return Err(AdapterError::ModelNotFound(config.model_path.clone()));
        }
        let gguf_path = if model_path.is_file() {
            config.model_path.clone()
        } else {
            let mut gguf_files: Vec<_> = std::fs::read_dir(model_path)
                .map_err(AdapterError::IOError)?
                .filter_map(|entry| entry.ok())
                .map(|entry| entry.path())
                .filter(|path| path.extension().map(|ext| ext == "gguf").unwrap_or(false))
                .collect();
            if gguf_files.is_empty() {
                return Err(AdapterError::ModelNotFound(format!(
                    "No .gguf files found in {}",
                    config.model_path
                )));
            }
            // `read_dir` yields entries in an unspecified, platform-dependent order; sort so the
            // same directory always resolves to the same file.
            gguf_files.sort();
            gguf_files[0].to_string_lossy().to_string()
        };
        let model =
            sys::llama_load_model_from_file(&gguf_path, config.gpu_layers).map_err(|e| {
                AdapterError::RuntimeError(format!(
                    "Failed to load model from {}: {}. \
                     This may indicate an unsupported GGUF architecture — \
                     check that the vendored llama.cpp version supports this model's architecture. \
                     Enable verbose logging with XYBRID_LLAMACPP_VERBOSITY=4 for C++ error details.",
                    gguf_path, e
                ))
            })?;
        let context = sys::llama_new_context_with_model(
            &model,
            config.context_length,
            config.n_threads,
            config.n_batch,
            config.flash_attn,
        )
        .map_err(|e| AdapterError::RuntimeError(format!("Failed to create context: {}", e)))?;

        // Install the new context first: this assignment drops the previous context while the
        // previous model is still alive, and replacing the model afterwards drops that model.
        // Building fresh mutexes (instead of unwrapping the old ones) also clears any poison
        // left by a panic during an earlier generation.
        self.context = Mutex::new(Some(context));
        self.model = Some(model);
        // The new context's KV cache is empty, so no previously recorded prefix is reusable.
        self.kv_state = Mutex::new(KvCacheState::default());
        self.config = Some(config.clone());
        Ok(())
    }

    /// Reports whether a model and a usable context are present.
    ///
    /// A poisoned context lock makes generation fail (`with_model_and_context` rejects it), so
    /// it is reported as "not loaded" instead of panicking.
    fn is_loaded(&self) -> bool {
        self.model.is_some()
            && self
                .context
                .lock()
                .map(|context| context.is_some())
                .unwrap_or(false)
    }

    /// Releases the context, then the model, and forgets the config and cached prefix.
    ///
    /// Fresh mutexes are installed rather than unwrapping the old ones, so a lock poisoned by
    /// an earlier panic cannot make `unload` panic. Always returns `Ok(())`.
    fn unload(&mut self) -> LlmResult<()> {
        // Context first: it was created from the model.
        self.context = Mutex::new(None);
        self.model = None;
        self.config = None;
        self.kv_state = Mutex::new(KvCacheState::default());
        Ok(())
    }

    /// Generates a chat completion for `messages` without streaming chunks to the caller.
    ///
    /// Processing steps:
    /// 1. Render the prompt with `sys::llama_format_chat`.
    /// 2. Reuse any cached prefix via `prepare_kv_cache_and_get_tail`.
    /// 3. Cut the decoded text at the first stop pattern from `config.stop_sequences`,
    ///    `CHAT_STOP_PATTERNS` and `CHAT_STOP_PATTERNS_BROKEN`.
    /// 4. Strip thinking tags and trim.
    ///
    /// `finish_reason` is `"stop"` when a pattern was found in the text or
    /// `llama_generate_streaming` reported `stopped_by_callback`; otherwise `"length"`.
    ///
    /// # Errors
    /// - `InvalidInput` if the prompt does not fit the context window.
    /// - `ModelNotLoaded` before `load()`.
    /// - Template, tokenizer and decoder errors.
    ///
    /// If generation fails, the KV cache and prefix record are cleared before returning.
    fn generate(
        &self,
        messages: &[ChatMessage],
        config: &GenerationConfig,
    ) -> LlmResult<GenerationOutput> {
        self.with_model_and_context(|model, context| {
            let prompt = sys::llama_format_chat(model, messages)?;
            let tokens = sys::llama_tokenize_special(model, &prompt, true)?;
            let n_ctx = sys::llama_n_ctx(context);
            if tokens.len() >= n_ctx {
                return Err(AdapterError::InvalidInput(format!(
                    "Input too long: {} tokens exceeds context window of {} tokens. \
                     Reduce the prompt size or conversation history.",
                    tokens.len(),
                    n_ctx
                )));
            }
            let (tail, n_past) =
                self.prepare_kv_cache_and_get_tail(model, context, &tokens, config.max_tokens)?;
            let prompt_token_count = tokens.len();
            xybrid_trace::add_metadata("tokens_in", prompt_token_count.to_string());
            let mut tel = StreamingTelemetry::new(prompt_token_count);
            let stream_result = sys::llama_generate_streaming(
                context,
                model,
                &tail,
                config.max_tokens,
                config.temperature,
                config.top_p,
                config.min_p,
                config.top_k,
                config.repetition_penalty,
                &config.stop_sequences,
                |_token_id, _token_text| {
                    tel.record_chunk();
                    Ok(())
                },
                n_past,
            );
            let (output_tokens, stopped_by_callback) = match stream_result {
                Ok(result) => result,
                Err(err) => {
                    // The prompt was already recorded as cached; after a failed decode that
                    // record may not match the native KV cache, so drop both.
                    self.reset_kv_cache_after_failed_stream(context);
                    return Err(err);
                }
            };
            let fields = tel.finalize(output_tokens.len());
            log::debug!(
                target: "xybrid_core",
                "Generated {} tokens. Last 10: {:?}",
                output_tokens.len(),
                output_tokens.iter().rev().take(10).collect::<Vec<_>>()
            );
            let mut text = sys::llama_detokenize(model, &output_tokens)?;
            // Cut the debug preview on a char boundary: slicing a `String` at an arbitrary byte
            // offset panics when a multi-byte UTF-8 character straddles it.
            let preview_end = (0..=text.len().min(200))
                .rev()
                .find(|&i| text.is_char_boundary(i))
                .unwrap_or(0);
            log::debug!(target: "xybrid_core", "LLM raw output ({} chars): {:?}", text.len(), &text[..preview_end]);
            log::debug!(target: "xybrid_core", "First 100 bytes: {:?}", text.as_bytes().iter().take(100).collect::<Vec<_>>());
            let final_stop_patterns = {
                let mut extras: Vec<&str> = CHAT_STOP_PATTERNS.to_vec();
                extras.extend_from_slice(CHAT_STOP_PATTERNS_BROKEN);
                merge_stop_patterns(&config.stop_sequences, &extras)
            };
            log::debug!(target: "xybrid_core", "Searching for stop patterns: {:?}", final_stop_patterns);
            let stopped_in_text = truncate_at_first_stop(&mut text, &final_stop_patterns);
            let text = strip_thinking_tags(&text).trim().to_string();
            let finish_reason = if stopped_in_text || stopped_by_callback {
                "stop"
            } else {
                "length"
            }
            .to_string();
            Ok(GenerationOutput {
                text,
                tokens_generated: output_tokens.len(),
                generation_time_ms: fields.generation_time_ms,
                tokens_per_second: fields.tokens_per_second,
                finish_reason,
                ttft_ms: fields.ttft_ms,
                mean_itl_ms: fields.mean_itl_ms,
                p95_itl_ms: fields.p95_itl_ms,
                emitted_chunks: fields.emitted_chunks,
                inter_chunk_ms: fields.inter_chunk_ms,
                decode_tps: fields.decode_tps,
                prefill_tps: fields.prefill_tps,
            })
        })
    }

    /// Generates from `prompt` verbatim (no chat template, no stop-pattern post-processing).
    ///
    /// The decoded text is only trimmed. `finish_reason` is `"stop"` only when
    /// `llama_generate_streaming` reported `stopped_by_callback`; otherwise it is `"length"`.
    ///
    /// # Errors
    /// - `InvalidInput` if the prompt does not fit the context window.
    /// - `ModelNotLoaded` before `load()`.
    /// - Tokenizer and decoder errors.
    ///
    /// If generation fails, the KV cache and prefix record are cleared before returning
    /// (as in `generate`).
    fn generate_raw(&self, prompt: &str, config: &GenerationConfig) -> LlmResult<GenerationOutput> {
        self.with_model_and_context(|model, context| {
            let tokens = sys::llama_tokenize_special(model, prompt, true)?;
            let n_ctx = sys::llama_n_ctx(context);
            if tokens.len() >= n_ctx {
                return Err(AdapterError::InvalidInput(format!(
                    "Input too long: {} tokens exceeds context window of {} tokens.",
                    tokens.len(),
                    n_ctx
                )));
            }
            let (tail, n_past) =
                self.prepare_kv_cache_and_get_tail(model, context, &tokens, config.max_tokens)?;
            let prompt_token_count = tokens.len();
            xybrid_trace::add_metadata("tokens_in", prompt_token_count.to_string());
            let mut tel = StreamingTelemetry::new(prompt_token_count);
            let stream_result = sys::llama_generate_streaming(
                context,
                model,
                &tail,
                config.max_tokens,
                config.temperature,
                config.top_p,
                config.min_p,
                config.top_k,
                config.repetition_penalty,
                &config.stop_sequences,
                |_token_id, _token_text| {
                    tel.record_chunk();
                    Ok(())
                },
                n_past,
            );
            let (output_tokens, stopped_by_callback) = match stream_result {
                Ok(result) => result,
                Err(err) => {
                    // Same recovery as `generate`: the recorded prefix may no longer match
                    // the native KV cache.
                    self.reset_kv_cache_after_failed_stream(context);
                    return Err(err);
                }
            };
            let fields = tel.finalize(output_tokens.len());
            let text = sys::llama_detokenize(model, &output_tokens)?;
            let text = text.trim().to_string();
            let finish_reason = if stopped_by_callback {
                "stop"
            } else {
                "length"
            }
            .to_string();
            Ok(GenerationOutput {
                text,
                tokens_generated: output_tokens.len(),
                generation_time_ms: fields.generation_time_ms,
                tokens_per_second: fields.tokens_per_second,
                finish_reason,
                ttft_ms: fields.ttft_ms,
                mean_itl_ms: fields.mean_itl_ms,
                p95_itl_ms: fields.p95_itl_ms,
                emitted_chunks: fields.emitted_chunks,
                inter_chunk_ms: fields.inter_chunk_ms,
                decode_tps: fields.decode_tps,
                prefill_tps: fields.prefill_tps,
            })
        })
    }

    /// Generates a chat completion, passing filtered text chunks to `on_token` as they appear.
    ///
    /// Chunks go through a `StreamingTextFilter` built from `config.stop_sequences` and
    /// `CHAT_STOP_PATTERNS`. If at least one chunk was emitted, a final empty chunk carrying
    /// the full post-processed text and the finish reason is sent. Any error that callback
    /// returns is ignored.
    ///
    /// # Errors
    /// - `InvalidInput` if the prompt does not fit the context window.
    /// - `ModelNotLoaded` before `load()`.
    /// - Template, tokenizer and decoder errors.
    /// - Errors from `llama_generate_streaming`, presumably including ones returned by
    ///   `on_token`.
    ///
    /// If generation fails, the KV cache and prefix record are cleared before returning
    /// (as in `generate`).
    fn generate_streaming(
        &self,
        messages: &[ChatMessage],
        config: &GenerationConfig,
        on_token: crate::runtime_adapter::llm::StreamingCallback<'_>,
    ) -> LlmResult<GenerationOutput> {
        use crate::runtime_adapter::llm::PartialToken;
        let mut on_token = on_token;
        self.with_model_and_context(|model, context| {
            let prompt = sys::llama_format_chat(model, messages)?;
            let tokens = sys::llama_tokenize_special(model, &prompt, true)?;
            let n_ctx = sys::llama_n_ctx(context);
            if tokens.len() >= n_ctx {
                return Err(AdapterError::InvalidInput(format!(
                    "Input too long: {} tokens exceeds context window of {} tokens. \
                     Reduce the prompt size or conversation history.",
                    tokens.len(),
                    n_ctx
                )));
            }
            let (tail, n_past) =
                self.prepare_kv_cache_and_get_tail(model, context, &tokens, config.max_tokens)?;
            let prompt_token_count = tokens.len();
            xybrid_trace::add_metadata("tokens_in", prompt_token_count.to_string());
            let mut tel = StreamingTelemetry::new(prompt_token_count);
            let stop_patterns = merge_stop_patterns(&config.stop_sequences, CHAT_STOP_PATTERNS);
            let mut filter = StreamingTextFilter::new(stop_patterns.clone());
            let mut token_index = 0usize;
            let stream_result = sys::llama_generate_streaming(
                context,
                model,
                &tail,
                config.max_tokens,
                config.temperature,
                config.top_p,
                config.min_p,
                config.top_k,
                config.repetition_penalty,
                &stop_patterns,
                |token_id, token_text| {
                    tel.record_chunk();
                    if let Some(safe_text) = filter.push(token_text) {
                        let partial = PartialToken::new(
                            safe_text,
                            token_index,
                            filter.cumulative_emitted().to_string(),
                        )
                        .with_token_id(token_id as i64);
                        token_index += 1;
                        on_token(partial)?;
                    }
                    Ok(())
                },
                n_past,
            );
            let (output_tokens, stopped_by_callback) = match stream_result {
                Ok(result) => result,
                Err(err) => {
                    // Same recovery as `generate`: the recorded prefix may no longer match
                    // the native KV cache (e.g. a run aborted through `on_token`).
                    self.reset_kv_cache_after_failed_stream(context);
                    return Err(err);
                }
            };
            let fields = tel.finalize(output_tokens.len());
            let final_patterns = {
                let mut extras: Vec<&str> = CHAT_STOP_PATTERNS.to_vec();
                extras.extend_from_slice(CHAT_STOP_PATTERNS_BROKEN);
                merge_stop_patterns(&config.stop_sequences, &extras)
            };
            let mut text = sys::llama_detokenize(model, &output_tokens)?;
            let stopped_full = truncate_at_first_stop(&mut text, &final_patterns);
            let trimmed_partial = trim_partial_stop_suffix(&mut text, &final_patterns);
            let text = strip_thinking_tags(&text).trim().to_string();
            let finish_reason =
                if filter.is_stopped() || stopped_full || trimmed_partial || stopped_by_callback {
                    "stop".to_string()
                } else {
                    "length".to_string()
                };
            if token_index > 0 {
                let final_partial = PartialToken::new(String::new(), token_index, text.clone())
                    .with_finish_reason(&finish_reason);
                // Deliberately best-effort: the result below is returned either way.
                let _ = on_token(final_partial);
            }
            Ok(GenerationOutput {
                text,
                tokens_generated: output_tokens.len(),
                generation_time_ms: fields.generation_time_ms,
                tokens_per_second: fields.tokens_per_second,
                finish_reason,
                ttft_ms: fields.ttft_ms,
                mean_itl_ms: fields.mean_itl_ms,
                p95_itl_ms: fields.p95_itl_ms,
                emitted_chunks: fields.emitted_chunks,
                inter_chunk_ms: fields.inter_chunk_ms,
                decode_tps: fields.decode_tps,
                prefill_tps: fields.prefill_tps,
            })
        })
    }

    fn supports_streaming(&self) -> bool {
        true
    }

    /// Not tracked for this backend.
    fn memory_usage(&self) -> Option<u64> {
        None
    }

    /// Context length from the last successful `load()`, or `None` when nothing is loaded.
    fn context_length(&self) -> Option<usize> {
        self.config.as_ref().map(|c| c.context_length)
    }

    /// Prefix length reused by the most recent generation (`None` after a reset or if the
    /// KV-state lock is poisoned).
    fn last_cached_prefix_len(&self) -> Option<usize> {
        self.kv_state.lock().ok().and_then(|s| s.last_prefix_hit)
    }
}
#[cfg(not(feature = "llm-llamacpp"))]
/// Placeholder used when the crate is built without the `llm-llamacpp` feature; its
/// constructor and `load`/`generate*` methods return errors.
pub struct LlamaCppBackend;
#[cfg(not(feature = "llm-llamacpp"))]
impl LlamaCppBackend {
    /// Always fails: this build was compiled without the `llm-llamacpp` feature.
    ///
    /// # Errors
    /// Returns `RuntimeError` explaining how to enable the feature.
    pub fn new() -> LlmResult<Self> {
        let message = "llm-llamacpp feature not enabled. Build with --features llm-llamacpp";
        Err(AdapterError::RuntimeError(message.to_string()))
    }
}
#[cfg(not(feature = "llm-llamacpp"))]
impl LlmBackend for LlamaCppBackend {
    fn name(&self) -> &str {
        "llama-cpp"
    }

    /// Same label as the feature-enabled backend, matching how `name` and
    /// `supported_formats` already mirror it, so the backend is identified consistently
    /// regardless of build features.
    fn wire_label(&self) -> Option<&'static str> {
        Some("llamacpp")
    }

    fn supported_formats(&self) -> Vec<&'static str> {
        vec!["gguf"]
    }

    /// Always fails: llama.cpp support is not compiled in.
    fn load(&mut self, _config: &LlmConfig) -> LlmResult<()> {
        Err(AdapterError::RuntimeError(
            "llm-llamacpp feature not enabled".to_string(),
        ))
    }

    fn is_loaded(&self) -> bool {
        false
    }

    /// Nothing to release; always succeeds.
    fn unload(&mut self) -> LlmResult<()> {
        Ok(())
    }

    /// Always fails: llama.cpp support is not compiled in.
    fn generate(
        &self,
        _messages: &[ChatMessage],
        _config: &GenerationConfig,
    ) -> LlmResult<GenerationOutput> {
        Err(AdapterError::RuntimeError(
            "llm-llamacpp feature not enabled".to_string(),
        ))
    }

    /// Always fails: llama.cpp support is not compiled in.
    fn generate_raw(
        &self,
        _prompt: &str,
        _config: &GenerationConfig,
    ) -> LlmResult<GenerationOutput> {
        Err(AdapterError::RuntimeError(
            "llm-llamacpp feature not enabled".to_string(),
        ))
    }
}
#[cfg(all(test, feature = "llm-llamacpp"))]
mod tests {
    //! Model-free unit tests: streaming capability, KV-state reset, and prefix computation.
    use super::*;

    /// Backend with no model loaded.
    fn fresh_backend() -> LlamaCppBackend {
        LlamaCppBackend::new().unwrap()
    }

    #[test]
    fn backend_reports_true_streaming_for_sdk_cancellation_gate() {
        let streams = fresh_backend().supports_streaming();
        assert!(
            streams,
            "llama.cpp must stay on the true streaming path so SDK abort checks can interrupt generation"
        );
    }

    #[test]
    fn failed_stream_resets_rust_kv_cache_state() {
        let backend = fresh_backend();
        *backend.kv_state.lock().unwrap() = KvCacheState {
            cached_tokens: vec![1, 2, 3],
            last_prefix_hit: Some(2),
        };

        backend.clear_cached_prefix_state();

        let kv = backend.kv_state.lock().unwrap();
        assert!(
            kv.cached_tokens.is_empty(),
            "failed streaming runs must not leave reusable prompt tokens behind"
        );
        assert_eq!(
            kv.last_prefix_hit, None,
            "failed streaming runs must clear prefix-hit metadata"
        );
    }

    #[test]
    fn lcp_empty_inputs_return_zero() {
        let empty: [i32; 0] = [];
        let three = [1, 2, 3];
        assert_eq!(compute_reusable_prefix_len(&empty, &three), 0);
        assert_eq!(compute_reusable_prefix_len(&three, &empty), 0);
        assert_eq!(compute_reusable_prefix_len(&empty, &empty), 0);
    }

    #[test]
    fn lcp_identical_prompts_keep_one_token_for_decoder() {
        let prompt = [10, 20, 30, 40];
        let reused = compute_reusable_prefix_len(&prompt, &prompt);
        assert_eq!(reused, 3);
    }

    #[test]
    fn lcp_partial_match_returns_shared_length() {
        let previous = [1, 2, 3, 4, 99, 99];
        let incoming = [1, 2, 3, 4, 50, 60, 70];
        assert_eq!(compute_reusable_prefix_len(&previous, &incoming), 4);
    }

    #[test]
    fn lcp_no_overlap_returns_zero() {
        let previous = [1, 2, 3];
        let incoming = [9, 8, 7];
        assert_eq!(compute_reusable_prefix_len(&previous, &incoming), 0);
    }

    #[test]
    fn lcp_caps_at_new_tokens_minus_one() {
        let previous = [1, 2, 3, 4, 5, 6];
        let incoming = [1, 2, 3];
        assert_eq!(compute_reusable_prefix_len(&previous, &incoming), 2);
    }

    #[test]
    fn lcp_single_token_new_prompt_returns_zero() {
        let previous = [1, 2, 3];
        let incoming = [1];
        assert_eq!(compute_reusable_prefix_len(&previous, &incoming), 0);
    }
}