use log::{debug, info, warn};
#[cfg(any(feature = "llm-mistral", feature = "llm-llamacpp"))]
use super::template::PostprocessingStep;
use super::template::{
backend_label_from_template, quantization_label_from_metadata, span_kind_from_template,
stage_kind_from_task, ExecutionMode, ExecutionTemplate, ModelMetadata, PipelineStage,
};
use crate::conversation::ConversationContext;
#[cfg(any(feature = "llm-mistral", feature = "llm-llamacpp"))]
use crate::ir::EnvelopeKind;
use crate::ir::{Envelope, MessageRole};
use crate::runtime_adapter::{AdapterError, ModelRuntime};
use crate::tracing as xybrid_trace;
use ndarray::ArrayD;
use std::collections::HashMap;
use std::path::Path;
use super::listener::ExecutionGuard;
/// Records how an execution ended on its [`ExecutionGuard`].
///
/// An error that carries a cloud-fallback abort reason is reported as a
/// controlled abort; any other error marks the execution as failed, using the
/// error's `Display` text as the failure message.
fn mark_execution_terminal(guard: &ExecutionGuard, error: &AdapterError) {
    if error.cloud_fallback_abort_reason().is_none() {
        guard.set_failed(error.to_string());
        return;
    }
    guard.set_controlled_abort();
}
/// Tags the current trace span with cost-attribution labels for an LLM call.
///
/// The `backend` label is derived from the execution template plus the
/// optional string `backend` entry of the model metadata. The `quantization`
/// label comes from `quantization_label_from_metadata`. Labels that cannot be
/// resolved are not recorded.
#[cfg(any(feature = "llm-mistral", feature = "llm-llamacpp"))]
fn stamp_llm_span_cost_attribution(metadata: &ModelMetadata) {
    let hint = metadata
        .metadata
        .get("backend")
        .and_then(|value| value.as_str());
    let backend_label = backend_label_from_template(&metadata.execution_template, hint);
    if let Some(backend) = backend_label {
        xybrid_trace::add_metadata("backend", backend);
    }
    let quant_label = quantization_label_from_metadata(metadata);
    if let Some(quantization) = quant_label {
        xybrid_trace::add_metadata("quantization", quantization);
    }
}
/// Records the loaded adapter's wire label as the span's `backend` metadata,
/// when the adapter reports one.
#[cfg(any(feature = "llm-mistral", feature = "llm-llamacpp"))]
fn stamp_llm_runtime_backend(adapter: &LlmRuntimeAdapter) {
    let reported = adapter.wire_label();
    if let Some(wire_label) = reported {
        xybrid_trace::add_metadata("backend", wire_label);
    }
}
use crate::execution::session_factory::OnnxSessionFactory;
use crate::runtime_adapter::onnx::{
ExecutionProviderKind, ONNXSession, OnnxRuntime, SessionOptions,
};
#[cfg(feature = "candle")]
use crate::runtime_adapter::candle::CandleRuntime;
#[cfg(any(feature = "llm-mistral", feature = "llm-llamacpp"))]
use crate::runtime_adapter::types::{ChatMessage, LlmConfig};
use crate::runtime_adapter::types::{GenerationConfig, StreamingCallback};
#[cfg(any(feature = "llm-mistral", feature = "llm-llamacpp"))]
use crate::runtime_adapter::llm::LlmRuntimeAdapter;
use super::modes::{
execute_autoregressive_stage, execute_bert_inference, execute_single_shot_stage,
execute_tts_inference, execute_whisper_decoder_stage,
};
use super::postprocessing;
use super::preprocessing;
use super::types::{ExecutorResult, PreprocessedData, RawOutputs};
use super::voice_loader::TtsVoiceLoader;
/// Executes models described by [`ModelMetadata`] execution templates.
///
/// Holds the named model runtimes, the base directory that model-relative
/// file names are resolved against, and a single-entry cache of the most
/// recently loaded LLM adapter.
pub struct TemplateExecutor {
    // Runtimes keyed by name ("onnx", plus "candle" with the `candle` feature by
    // default); looked up by the runtime name chosen from the execution template.
    runtimes: HashMap<String, Box<dyn ModelRuntime>>,
    // Directory joined with model / chat-template file names from the metadata.
    base_path: String,
    // Most recently loaded LLM adapter, keyed by its full model path string.
    // Only one adapter is kept; loading a different model path replaces it.
    #[cfg(any(feature = "llm-mistral", feature = "llm-llamacpp"))]
    llm_adapter_cache: Option<(String, LlmRuntimeAdapter)>,
    // Placeholder so the field (initialised to `None`) exists even when no LLM
    // backend feature is compiled in.
    #[cfg(not(any(feature = "llm-mistral", feature = "llm-llamacpp")))]
    llm_adapter_cache: Option<()>,
}
impl TemplateExecutor {
/// Creates an executor rooted at `base_path` that uses the runtimes from
/// [`Self::default_runtimes`].
pub fn new(base_path: &str) -> Self {
    let runtimes = Self::default_runtimes();
    Self::with_runtimes(base_path, runtimes)
}
pub fn with_base_path(base_path: &str) -> Self {
Self::new(base_path)
}
/// Creates an executor rooted at `base_path` that uses exactly the supplied
/// `runtimes`. The LLM adapter cache starts empty.
pub fn with_runtimes(base_path: &str, runtimes: HashMap<String, Box<dyn ModelRuntime>>) -> Self {
    let base_path = base_path.to_string();
    Self {
        base_path,
        runtimes,
        llm_adapter_cache: None,
    }
}
/// Builds the default runtime registry.
///
/// It always contains `"onnx"` ([`OnnxRuntime`]) and also `"candle"` when the
/// `candle` feature is enabled.
pub fn default_runtimes() -> HashMap<String, Box<dyn ModelRuntime>> {
    let mut registry = HashMap::<String, Box<dyn ModelRuntime>>::new();
    let onnx: Box<dyn ModelRuntime> = Box::new(OnnxRuntime::new());
    registry.insert(String::from("onnx"), onnx);
    #[cfg(feature = "candle")]
    {
        let candle: Box<dyn ModelRuntime> = Box::new(CandleRuntime::new());
        registry.insert(String::from("candle"), candle);
    }
    registry
}
/// Registers `runtime` under `name`, replacing (and dropping) any runtime
/// previously registered under the same name.
pub fn register_runtime(&mut self, name: impl Into<String>, runtime: Box<dyn ModelRuntime>) {
    let key: String = name.into();
    self.runtimes.insert(key, runtime);
}
/// Returns the runtime registered under `name`, or `None` if there is none.
pub fn get_runtime(&self, name: &str) -> Option<&dyn ModelRuntime> {
    self.runtimes.get(name).map(|boxed| &**boxed)
}
/// Lists the names of all registered runtimes.
///
/// The order follows `HashMap` iteration and is unspecified.
pub fn list_runtimes(&self) -> Vec<&str> {
    self.runtimes.keys().map(String::as_str).collect()
}
/// Executes the model described by `metadata` on `input`.
///
/// `config` optionally supplies generation settings (used by the LLM path).
/// The work is tracked by a silent [`ExecutionGuard`]. On error the guard is
/// marked as a controlled abort or a failure, and the error is then returned
/// unchanged.
pub fn execute(
    &mut self,
    metadata: &ModelMetadata,
    input: &Envelope,
    config: Option<&GenerationConfig>,
) -> ExecutorResult<Envelope> {
    let guard = ExecutionGuard::new_silent(&metadata.model_id, "execute");
    self.execute_impl(metadata, input, config).map_err(|err| {
        mark_execution_terminal(&guard, &err);
        err
    })
}
/// Single-request execution behind [`Self::execute`].
///
/// The public wrapper does the guard bookkeeping. This function records
/// model/trace metadata on an `execute:<model_id>` span and then dispatches on
/// the execution template:
///
/// 1. `ModelGraph`: preprocessing, then the multi-stage pipeline, then
///    postprocessing.
/// 2. `Gguf` with a `CodecDecode` postprocessing step (LLM features only): the
///    request is delegated to `CodecTtsStrategy`.
/// 3. `Gguf` (LLM features only): single-turn generation via `execute_llm`.
///    Without LLM features this returns an error.
/// 4. `SafeTensors` / `Onnx` / `CoreMl` / `TfLite`: mapped to the runtime
///    names `"candle"` / `"onnx"` / `"coreml"` / `"tflite"` plus a model file
///    under the base path.
///    - TTS models go through `execute_tts_chunked`.
///    - Otherwise preprocessing runs. Token-id inputs use a BERT-style ONNX
///      session on the CPU execution provider; all other inputs go to the named
///      runtime. Postprocessing follows in both cases.
///
/// # Errors
/// Propagates preprocessing, loading, inference and postprocessing errors.
/// Returns `AdapterError::RuntimeError` when the selected runtime is not
/// registered. Only `"onnx"` and, with the `candle` feature, `"candle"` are
/// registered by default.
fn execute_impl(
    &mut self,
    metadata: &ModelMetadata,
    input: &Envelope,
    // Only read by the GGUF/LLM branch; unused when no LLM feature is enabled.
    #[allow(unused_variables)] config: Option<&GenerationConfig>,
) -> ExecutorResult<Envelope> {
    debug!(
        target: "xybrid_core",
        "TemplateExecutor.execute START: model_id={}, template={:?}",
        metadata.model_id,
        std::mem::discriminant(&metadata.execution_template)
    );
    info!(
        target: "xybrid_core",
        "Executing model: {} v{}",
        metadata.model_id,
        metadata.version
    );
    debug!(
        target: "xybrid_core",
        "Input envelope kind: {}",
        input.kind_str()
    );
    // Span covering the whole execution; the metadata below is attached to it.
    let _exec_span = xybrid_trace::SpanGuard::new(format!("execute:{}", metadata.model_id));
    xybrid_trace::add_metadata("model_id", &metadata.model_id);
    xybrid_trace::add_metadata("version", &metadata.version);
    // Optional string `task` entry: recorded verbatim, plus a derived stage kind if known.
    if let Some(task) = metadata.metadata.get("task").and_then(|v| v.as_str()) {
        if let Some(kind) = stage_kind_from_task(task) {
            xybrid_trace::add_metadata("stage_kind", kind);
        }
        xybrid_trace::add_metadata("task", task);
    }
    xybrid_trace::add_metadata(
        "span_kind",
        span_kind_from_template(&metadata.execution_template),
    );
    // Cost attribution: backend label (template + optional `backend` hint) and quantization.
    let backend_hint = metadata.metadata.get("backend").and_then(|v| v.as_str());
    if let Some(label) = backend_label_from_template(&metadata.execution_template, backend_hint)
    {
        xybrid_trace::add_metadata("backend", label);
    }
    if let Some(quant) = quantization_label_from_metadata(metadata) {
        xybrid_trace::add_metadata("quantization", quant);
    }
    // Multi-stage graphs. Note: `config` here shadows the generation config parameter
    // and is the graph's own configuration map.
    if let ExecutionTemplate::ModelGraph { stages, config } = &metadata.execution_template {
        info!(
            target: "xybrid_core",
            "Executing model graph with {} stages",
            stages.len()
        );
        let _span = xybrid_trace::SpanGuard::new("model_graph_inference");
        xybrid_trace::add_metadata("stages", stages.len().to_string());
        let preprocessed = self.run_preprocessing(metadata, input)?;
        let raw_outputs = self.execute_pipeline(stages, config, preprocessed, metadata)?;
        return self.run_postprocessing(metadata, raw_outputs);
    }
    // GGUF models with a CodecDecode postprocessing step are handled as codec TTS
    // by a dedicated strategy instead of plain LLM text generation.
    #[cfg(any(feature = "llm-mistral", feature = "llm-llamacpp"))]
    if matches!(metadata.execution_template, ExecutionTemplate::Gguf { .. })
        && metadata
            .postprocessing
            .iter()
            .any(|s| matches!(s, PostprocessingStep::CodecDecode { .. }))
    {
        use super::strategies::{CodecTtsStrategy, ExecutionContext, ExecutionStrategy};
        debug!(
            target: "xybrid_core",
            "Detected codec TTS metadata, dispatching to CodecTtsStrategy"
        );
        let strategy = CodecTtsStrategy::new();
        let mut ctx = ExecutionContext {
            base_path: &self.base_path,
            runtimes: &mut self.runtimes,
        };
        return strategy.execute(&mut ctx, metadata, input);
    }
    // Single-model templates: pick the runtime name and model file. GGUF returns
    // early through the LLM path (or an error when LLM features are disabled).
    let (runtime_type, model_file) = match &metadata.execution_template {
        ExecutionTemplate::SafeTensors { model_file, .. } => ("candle", model_file.clone()),
        ExecutionTemplate::Onnx { model_file } => ("onnx", model_file.clone()),
        ExecutionTemplate::CoreMl { model_file } => ("coreml", model_file.clone()),
        ExecutionTemplate::TfLite { model_file } => ("tflite", model_file.clone()),
        // Already handled by the early return above; kept for exhaustiveness.
        ExecutionTemplate::ModelGraph { .. } => {
            return Err(AdapterError::RuntimeError(
                "ModelGraph execution should not reach single model path".to_string(),
            ));
        }
        #[cfg(any(feature = "llm-mistral", feature = "llm-llamacpp"))]
        ExecutionTemplate::Gguf {
            model_file,
            chat_template,
            context_length,
            ..
        } => {
            debug!(
                target: "xybrid_core",
                "Detected GGUF template, routing to execute_llm()"
            );
            debug!(
                target: "xybrid_core",
                "GGUF model_file: {}, chat_template: {:?}, context_length: {}",
                model_file,
                chat_template,
                context_length
            );
            let backend_hint = metadata.metadata.get("backend").and_then(|v| v.as_str());
            return self.execute_llm(
                metadata,
                model_file,
                chat_template.as_deref(),
                *context_length,
                input,
                backend_hint,
                config,
            );
        }
        #[cfg(not(any(feature = "llm-mistral", feature = "llm-llamacpp")))]
        ExecutionTemplate::Gguf { .. } => {
            return Err(AdapterError::RuntimeError(
                "GGUF/LLM execution requires the 'llm-mistral' or 'llm-llamacpp' feature"
                    .to_string(),
            ));
        }
    };
    debug!(
        target: "xybrid_core",
        "Using {} runtime with model: {}",
        runtime_type,
        model_file
    );
    let model_full_path = Path::new(&self.base_path).join(&model_file);
    // TTS models take a dedicated chunked path and skip the generic pipeline below.
    let is_tts = Self::is_tts_model(metadata);
    debug!(
        target: "xybrid_core",
        "Checking TTS: is_tts_model={}, preprocessing steps: {:?}",
        is_tts,
        metadata.preprocessing.iter().map(|s| s.step_name()).collect::<Vec<_>>()
    );
    if is_tts {
        debug!(target: "xybrid_core", "TTS detected, calling execute_tts_chunked");
        return self.execute_tts_chunked(metadata, input, &model_full_path);
    }
    let preprocessed = self.run_preprocessing(metadata, input)?;
    let result_envelope = if preprocessed.is_token_ids() {
        // Token-id input (tokenizer preprocessing): run a BERT-style ONNX session
        // directly. A fresh CPU session is created on every call; it is not cached.
        debug!(target: "xybrid_core", "Detected BERT-style inference (token IDs)");
        let (ids, attention_mask, token_type_ids) = preprocessed
            .as_token_ids()
            .ok_or_else(|| AdapterError::InvalidInput("Expected token IDs".to_string()))?;
        let session = OnnxSessionFactory::create_session(
            &model_full_path,
            ExecutionProviderKind::Cpu,
            SessionOptions::default(),
        )?;
        let raw_outputs =
            execute_bert_inference(&session, ids, attention_mask, token_type_ids)?;
        crate::runtime_adapter::tensor_utils::tensors_to_envelope(
            &raw_outputs,
            session.output_names(),
        )?
    } else {
        // Generic path: hand the preprocessed envelope to the named runtime.
        debug!(target: "xybrid_core", "Using standard execution path");
        let runtime_input = preprocessed.to_envelope()?;
        let runtime = self.runtimes.get_mut(runtime_type).ok_or_else(|| {
            AdapterError::RuntimeError(format!("Runtime '{}' not configured", runtime_type))
        })?;
        debug!(target: "xybrid_core", "Loading model: {:?}", model_full_path);
        runtime
            .load(&model_full_path)
            .map_err(|e| AdapterError::RuntimeError(format!("Load failed: {}", e)))?;
        debug!(target: "xybrid_core", "Running inference");
        runtime.execute(&runtime_input)?
    };
    let raw_outputs = RawOutputs::from_envelope(&result_envelope)?;
    let result = self.run_postprocessing(metadata, raw_outputs)?;
    info!(
        target: "xybrid_core",
        "Model execution complete: {} -> {}",
        metadata.model_id,
        result.kind_str()
    );
    Ok(result)
}
/// Executes the model described by `metadata` on `input` within a
/// conversation `context`.
///
/// The reply envelope is tagged with the `Assistant` role. On error the
/// silent [`ExecutionGuard`] is marked as a controlled abort or a failure,
/// and the error is then returned unchanged.
pub fn execute_with_context(
    &mut self,
    metadata: &ModelMetadata,
    input: &Envelope,
    context: &ConversationContext,
    config: Option<&GenerationConfig>,
) -> ExecutorResult<Envelope> {
    let guard = ExecutionGuard::new_silent(&metadata.model_id, "execute_with_context");
    self.execute_with_context_impl(metadata, input, context, config)
        .map_err(|err| {
            mark_execution_terminal(&guard, &err);
            err
        })
}
/// Conversation-aware execution behind [`Self::execute_with_context`].
///
/// Logs a warning when `input` is already the newest entry in the context
/// history, because it would then appear twice in the prompt.
///
/// - GGUF models (LLM features only): `context.context_for_llm()` followed by
///   `input` become chat messages. Only text envelopes are used, and a missing
///   role defaults to `User`. Generation runs through
///   `execute_llm_with_messages`.
/// - Every other model ignores `context` and runs through `execute_impl`.
///
/// In both cases the returned envelope is tagged with the `Assistant` role.
fn execute_with_context_impl(
    &mut self,
    metadata: &ModelMetadata,
    input: &Envelope,
    context: &ConversationContext,
    config: Option<&GenerationConfig>,
) -> ExecutorResult<Envelope> {
    debug!(
        target: "xybrid_core",
        "TemplateExecutor.execute_with_context START: model_id={}, context_id={}",
        metadata.model_id,
        context.id()
    );
    let input_already_in_history = matches!(
        context.history().last(),
        Some(newest) if newest.local_id() == input.local_id()
    );
    if input_already_in_history {
        warn!(
            target: "xybrid_core",
            "Input envelope was already pushed to context (local_id={}). \
             This will cause the message to appear twice in the prompt. \
             Push input to context AFTER execute_with_context, not before.",
            input.local_id()
        );
    }
    #[cfg(any(feature = "llm-mistral", feature = "llm-llamacpp"))]
    {
        if let ExecutionTemplate::Gguf {
            model_file,
            context_length,
            ..
        } = &metadata.execution_template
        {
            debug!(
                target: "xybrid_core",
                "LLM model detected, converting context to ChatMessages"
            );
            let mut chat_messages: Vec<ChatMessage> = context
                .context_for_llm()
                .into_iter()
                .filter_map(|envelope| match &envelope.kind {
                    EnvelopeKind::Text(text) => Some(ChatMessage {
                        role: envelope.role().unwrap_or(MessageRole::User),
                        content: text.clone(),
                    }),
                    _ => None,
                })
                .collect();
            if let EnvelopeKind::Text(text) = &input.kind {
                chat_messages.push(ChatMessage {
                    role: input.role().unwrap_or(MessageRole::User),
                    content: text.clone(),
                });
            }
            debug!(
                target: "xybrid_core",
                "Converted {} messages for LLM",
                chat_messages.len()
            );
            let hint = metadata.metadata.get("backend").and_then(|v| v.as_str());
            return self
                .execute_llm_with_messages(
                    metadata,
                    model_file,
                    *context_length,
                    &chat_messages,
                    hint,
                    config,
                )
                .map(|reply| reply.with_role(MessageRole::Assistant));
        }
    }
    debug!(
        target: "xybrid_core",
        "Non-LLM model, executing without context transformation"
    );
    self.execute_impl(metadata, input, config)
        .map(|reply| reply.with_role(MessageRole::Assistant))
}
/// Streaming variant of [`Self::execute`].
///
/// GGUF/LLM models report generated tokens through `on_token`; other models
/// fall back to the regular execution path. On error the silent
/// [`ExecutionGuard`] is marked as a controlled abort or a failure, and the
/// error is then returned unchanged.
pub fn execute_streaming(
    &mut self,
    metadata: &ModelMetadata,
    input: &Envelope,
    on_token: StreamingCallback<'_>,
    config: Option<&GenerationConfig>,
) -> ExecutorResult<Envelope> {
    let guard = ExecutionGuard::new_silent(&metadata.model_id, "execute_streaming");
    self.execute_streaming_impl(metadata, input, on_token, config)
        .map_err(|err| {
            mark_execution_terminal(&guard, &err);
            err
        })
}
/// Streaming execution behind [`Self::execute_streaming`].
///
/// With LLM features enabled, GGUF templates are routed to
/// `execute_llm_streaming`. Every other case falls back to `execute_impl`,
/// which does not call `on_token`.
fn execute_streaming_impl(
    &mut self,
    metadata: &ModelMetadata,
    input: &Envelope,
    // Only consumed by the LLM branch; unused without LLM features.
    #[allow(unused_variables)] on_token: StreamingCallback<'_>,
    #[allow(unused_variables)] config: Option<&GenerationConfig>,
) -> ExecutorResult<Envelope> {
    #[cfg(any(feature = "llm-mistral", feature = "llm-llamacpp"))]
    {
        match &metadata.execution_template {
            ExecutionTemplate::Gguf {
                model_file,
                chat_template,
                context_length,
                ..
            } => {
                let hint = metadata.metadata.get("backend").and_then(|v| v.as_str());
                return self.execute_llm_streaming(
                    metadata,
                    model_file,
                    chat_template.as_deref(),
                    *context_length,
                    input,
                    hint,
                    on_token,
                    config,
                );
            }
            _ => {
                debug!(
                    target: "xybrid_core",
                    "execute_streaming: Non-LLM model, falling back to regular execute()"
                );
            }
        }
    }
    #[cfg(not(any(feature = "llm-mistral", feature = "llm-llamacpp")))]
    {
        debug!(
            target: "xybrid_core",
            "execute_streaming: LLM features not enabled, falling back to regular execute()"
        );
    }
    self.execute_impl(metadata, input, config)
}
/// Streaming variant of [`Self::execute_with_context`].
///
/// GGUF/LLM models stream generated tokens through `on_token` while using the
/// conversation `context`. On error the silent [`ExecutionGuard`] is marked
/// as a controlled abort or a failure, and the error is then returned
/// unchanged.
pub fn execute_streaming_with_context(
    &mut self,
    metadata: &ModelMetadata,
    input: &Envelope,
    context: &ConversationContext,
    on_token: StreamingCallback<'_>,
    config: Option<&GenerationConfig>,
) -> ExecutorResult<Envelope> {
    let guard =
        ExecutionGuard::new_silent(&metadata.model_id, "execute_streaming_with_context");
    self.execute_streaming_with_context_impl(metadata, input, context, on_token, config)
        .map_err(|err| {
            mark_execution_terminal(&guard, &err);
            err
        })
}
/// Conversation-aware streaming execution behind
/// [`Self::execute_streaming_with_context`].
///
/// Logs a warning when `input` is already the newest entry in the context
/// history, because it would then appear twice in the prompt.
///
/// With LLM features enabled:
/// - GGUF models: `context.context_for_llm()` followed by `input` become chat
///   messages. Only text envelopes are used, and a missing role defaults to
///   `User`. Generation streams through
///   `execute_llm_streaming_with_messages`.
/// - Other models run through `execute_streaming_impl` without using
///   `context`.
///
/// Without LLM features the call delegates to `execute_with_context_impl`,
/// and `on_token` is never invoked.
///
/// Successful results are tagged with the `Assistant` role.
// `on_token` is unused when no LLM feature is compiled in.
#[allow(unused_variables)]
fn execute_streaming_with_context_impl(
    &mut self,
    metadata: &ModelMetadata,
    input: &Envelope,
    context: &ConversationContext,
    on_token: StreamingCallback<'_>,
    config: Option<&GenerationConfig>,
) -> ExecutorResult<Envelope> {
    debug!(
        target: "xybrid_core",
        "TemplateExecutor.execute_streaming_with_context START: model_id={}, context_id={}",
        metadata.model_id,
        context.id()
    );
    // Caller-misuse check only: warns but still proceeds with execution.
    if let Some(last) = context.history().last() {
        if last.local_id() == input.local_id() {
            warn!(
                target: "xybrid_core",
                "Input envelope was already pushed to context (local_id={}). \
                 This will cause the message to appear twice in the prompt. \
                 Push input to context AFTER execute_streaming_with_context, not before.",
                input.local_id()
            );
        }
    }
    #[cfg(any(feature = "llm-mistral", feature = "llm-llamacpp"))]
    {
        if let ExecutionTemplate::Gguf {
            model_file,
            context_length,
            ..
        } = &metadata.execution_template
        {
            debug!(
                target: "xybrid_core",
                "LLM model detected, converting context to ChatMessages for streaming"
            );
            // History first, then the new input; non-text envelopes are skipped.
            let mut chat_messages: Vec<ChatMessage> = Vec::new();
            for envelope in context.context_for_llm() {
                if let EnvelopeKind::Text(text) = &envelope.kind {
                    let role = envelope.role().unwrap_or(MessageRole::User);
                    chat_messages.push(ChatMessage {
                        role,
                        content: text.clone(),
                    });
                }
            }
            if let EnvelopeKind::Text(text) = &input.kind {
                let role = input.role().unwrap_or(MessageRole::User);
                chat_messages.push(ChatMessage {
                    role,
                    content: text.clone(),
                });
            }
            debug!(
                target: "xybrid_core",
                "Converted {} messages for LLM",
                chat_messages.len()
            );
            let backend_hint = metadata.metadata.get("backend").and_then(|v| v.as_str());
            let result = self.execute_llm_streaming_with_messages(
                metadata,
                model_file,
                *context_length,
                &chat_messages,
                backend_hint,
                on_token,
                config,
            )?;
            let result = result.with_role(MessageRole::Assistant);
            return Ok(result);
        }
        debug!(
            target: "xybrid_core",
            "Non-LLM model, executing streaming without context transformation"
        );
        // Non-GGUF models: `context` is not used on this path.
        let mut result = self.execute_streaming_impl(metadata, input, on_token, config)?;
        result = result.with_role(MessageRole::Assistant);
        Ok(result)
    }
    #[cfg(not(any(feature = "llm-mistral", feature = "llm-llamacpp")))]
    {
        debug!(
            target: "xybrid_core",
            "execute_streaming_with_context: LLM features not enabled, using execute_with_context()"
        );
        // No streaming available: `on_token` is dropped without being called.
        self.execute_with_context_impl(metadata, input, context, config)
    }
}
/// Runs single-turn streaming LLM generation for a GGUF model.
///
/// The prompt is `input`'s text. An optional `system_prompt` entry in
/// `input.metadata` is sent as a leading system message. Tokens go to
/// `on_token` as the backend streams them. The returned envelope holds the
/// full text plus generation statistics (`tokens_generated`,
/// `generation_time_ms`, `tokens_per_second`, `finish_reason`, and whatever
/// `insert_llm_streaming_metrics` adds).
///
/// The LLM adapter is cached per full model path. On a cache miss a new
/// adapter is created for `backend_hint` and the model is loaded with
/// `context_length`. When `chat_template` is given, it is resolved against the
/// executor's base path and applied as well.
///
/// Without an explicit `config`, generation starts from
/// `GenerationConfig::default()`. `max_tokens` and `temperature` entries in
/// `input.metadata` override those defaults when they parse.
///
/// # Errors
/// Returns `AdapterError::InvalidInput` if `input` is not text. This check
/// runs before the adapter cache is consulted, so a rejected request never
/// evicts the cached adapter or triggers a model load. Adapter creation, model
/// loading and generation errors are propagated.
#[cfg(any(feature = "llm-mistral", feature = "llm-llamacpp"))]
fn execute_llm_streaming(
    &mut self,
    metadata: &ModelMetadata,
    model_file: &str,
    chat_template: Option<&str>,
    context_length: usize,
    input: &Envelope,
    backend_hint: Option<&str>,
    on_token: StreamingCallback<'_>,
    config: Option<&GenerationConfig>,
) -> ExecutorResult<Envelope> {
    info!(
        target: "xybrid_core",
        "Executing LLM inference with streaming: {} (backend: {:?})",
        model_file,
        backend_hint.unwrap_or("default")
    );
    let _llm_span = xybrid_trace::SpanGuard::new("llm_inference_streaming");
    xybrid_trace::add_metadata("model", model_file);
    xybrid_trace::add_metadata("streaming", "true");
    stamp_llm_span_cost_attribution(metadata);

    // Validate the cheap-to-check input before the expensive model load, so an
    // invalid request cannot replace the cached adapter only to be rejected.
    let prompt = match &input.kind {
        EnvelopeKind::Text(text) => text.clone(),
        _ => {
            return Err(AdapterError::InvalidInput(
                "LLM streaming requires text input".to_string(),
            ))
        }
    };

    let model_path = Path::new(&self.base_path).join(model_file);
    let model_path_str = model_path.to_string_lossy().to_string();
    let need_load = match &self.llm_adapter_cache {
        Some((cached_path, _)) if cached_path == &model_path_str => false,
        _ => true,
    };
    if need_load {
        let mut llm_config =
            LlmConfig::new(model_path_str.clone()).with_context_length(context_length);
        if let Some(template) = chat_template {
            let template_path = Path::new(&self.base_path).join(template);
            llm_config =
                llm_config.with_chat_template(template_path.to_string_lossy().to_string());
        }
        let mut adapter = LlmRuntimeAdapter::with_backend_hint(backend_hint)?;
        adapter.load_model_with_config(&llm_config)?;
        self.llm_adapter_cache = Some((model_path_str.clone(), adapter));
    }

    let system_prompt = input.metadata.get("system_prompt").map(|s| s.as_str());
    let mut messages = Vec::new();
    if let Some(sys) = system_prompt {
        messages.push(ChatMessage::system(sys));
    }
    messages.push(ChatMessage::user(&prompt));

    // An explicit config wins; otherwise envelope metadata may override defaults.
    let gen_config = if let Some(cfg) = config {
        cfg.clone()
    } else {
        let mut cfg = GenerationConfig::default();
        if let Some(max_tokens) = input
            .metadata
            .get("max_tokens")
            .and_then(|s| s.parse().ok())
        {
            cfg.max_tokens = max_tokens;
        }
        if let Some(temperature) = input
            .metadata
            .get("temperature")
            .and_then(|s| s.parse().ok())
        {
            cfg.temperature = temperature;
        }
        cfg
    };

    let (output, backend_name, cached_prefix) =
        if let Some((_, adapter)) = &self.llm_adapter_cache {
            stamp_llm_runtime_backend(adapter);
            let backend = adapter.backend();
            let out = backend.generate_streaming(&messages, &gen_config, on_token)?;
            let name = backend.name().to_string();
            let cached = backend.last_cached_prefix_len();
            (out, name, cached)
        } else {
            return Err(AdapterError::RuntimeError(
                "LLM adapter cache unexpectedly empty".to_string(),
            ));
        };

    let mut response_metadata = std::collections::HashMap::new();
    response_metadata.insert(
        "tokens_generated".to_string(),
        output.tokens_generated.to_string(),
    );
    response_metadata.insert(
        "generation_time_ms".to_string(),
        output.generation_time_ms.to_string(),
    );
    response_metadata.insert(
        "tokens_per_second".to_string(),
        format!("{:.2}", output.tokens_per_second),
    );
    response_metadata.insert("finish_reason".to_string(), output.finish_reason.clone());
    insert_llm_streaming_metrics(&mut response_metadata, &output);
    mirror_llm_metrics_to_span(&output, &backend_name, cached_prefix);
    Ok(Envelope {
        kind: EnvelopeKind::Text(output.text),
        metadata: response_metadata,
    })
}
/// Runs non-streaming LLM generation over an explicit chat transcript.
///
/// This is used by the conversation-aware entry points. `messages` is passed
/// to the backend as-is, so any system prompt and history must already be in
/// it. Generation settings are `config` when supplied, otherwise
/// `GenerationConfig::default()`. Envelope-metadata overrides do not apply
/// here.
///
/// The LLM adapter is cached per full model path. When a new adapter has to
/// be loaded, it gets `context_length`. If the GGUF template declares a
/// `chat_template`, that template is resolved against the executor's base
/// path and applied too, matching the single-turn LLM paths.
///
/// # Errors
/// Propagates adapter creation, model loading and generation errors.
#[cfg(any(feature = "llm-mistral", feature = "llm-llamacpp"))]
fn execute_llm_with_messages(
    &mut self,
    metadata: &ModelMetadata,
    model_file: &str,
    context_length: usize,
    messages: &[ChatMessage],
    backend_hint: Option<&str>,
    config: Option<&GenerationConfig>,
) -> ExecutorResult<Envelope> {
    info!(
        target: "xybrid_core",
        "Executing LLM with {} ChatMessages: {} (backend: {:?})",
        messages.len(),
        model_file,
        backend_hint.unwrap_or("default")
    );
    let _llm_span = xybrid_trace::SpanGuard::new("llm_inference_with_messages");
    xybrid_trace::add_metadata("model", model_file);
    xybrid_trace::add_metadata("message_count", messages.len().to_string());
    stamp_llm_span_cost_attribution(metadata);
    let model_path = Path::new(&self.base_path).join(model_file);
    let model_path_str = model_path.to_string_lossy().to_string();
    let need_load = match &self.llm_adapter_cache {
        Some((cached_path, _)) if cached_path == &model_path_str => false,
        _ => true,
    };
    if need_load {
        let mut llm_config =
            LlmConfig::new(model_path_str.clone()).with_context_length(context_length);
        // The adapter cache is keyed only by model path and is shared with
        // `execute_llm` / `execute_llm_streaming`, which apply the template's chat
        // template on load. Apply it here as well, so the prompt format does not
        // depend on which entry point happened to load the model first.
        let declared_template = match &metadata.execution_template {
            ExecutionTemplate::Gguf { chat_template, .. } => chat_template.as_deref(),
            _ => None,
        };
        if let Some(template) = declared_template {
            let template_path = Path::new(&self.base_path).join(template);
            llm_config =
                llm_config.with_chat_template(template_path.to_string_lossy().to_string());
        }
        let mut adapter = LlmRuntimeAdapter::with_backend_hint(backend_hint)?;
        adapter.load_model_with_config(&llm_config)?;
        self.llm_adapter_cache = Some((model_path_str.clone(), adapter));
    }
    let gen_config = config.cloned().unwrap_or_default();
    let (output, backend_name, cached_prefix) =
        if let Some((_, adapter)) = &self.llm_adapter_cache {
            stamp_llm_runtime_backend(adapter);
            let backend = adapter.backend();
            let out = backend.generate(messages, &gen_config)?;
            let name = backend.name().to_string();
            let cached = backend.last_cached_prefix_len();
            (out, name, cached)
        } else {
            return Err(AdapterError::RuntimeError(
                "LLM adapter cache unexpectedly empty".to_string(),
            ));
        };
    let mut response_metadata = std::collections::HashMap::new();
    response_metadata.insert(
        "tokens_generated".to_string(),
        output.tokens_generated.to_string(),
    );
    response_metadata.insert(
        "generation_time_ms".to_string(),
        output.generation_time_ms.to_string(),
    );
    response_metadata.insert(
        "tokens_per_second".to_string(),
        format!("{:.2}", output.tokens_per_second),
    );
    response_metadata.insert("finish_reason".to_string(), output.finish_reason.clone());
    insert_llm_streaming_metrics(&mut response_metadata, &output);
    mirror_llm_metrics_to_span(&output, &backend_name, cached_prefix);
    Ok(Envelope {
        kind: EnvelopeKind::Text(output.text),
        metadata: response_metadata,
    })
}
/// Runs streaming LLM generation over an explicit chat transcript.
///
/// `messages` is passed to the backend as-is, and tokens are delivered
/// through `on_token`. Generation settings are `config` when supplied,
/// otherwise `GenerationConfig::default()`.
///
/// The LLM adapter is cached per full model path. When a new adapter has to
/// be loaded, it gets `context_length`. If the GGUF template declares a
/// `chat_template`, that template is resolved against the executor's base
/// path and applied too, matching the single-turn LLM paths.
///
/// # Errors
/// Propagates adapter creation, model loading and generation errors.
#[cfg(any(feature = "llm-mistral", feature = "llm-llamacpp"))]
fn execute_llm_streaming_with_messages(
    &mut self,
    metadata: &ModelMetadata,
    model_file: &str,
    context_length: usize,
    messages: &[ChatMessage],
    backend_hint: Option<&str>,
    on_token: StreamingCallback<'_>,
    config: Option<&GenerationConfig>,
) -> ExecutorResult<Envelope> {
    info!(
        target: "xybrid_core",
        "Executing LLM streaming with {} ChatMessages: {} (backend: {:?})",
        messages.len(),
        model_file,
        backend_hint.unwrap_or("default")
    );
    let _llm_span = xybrid_trace::SpanGuard::new("llm_inference_streaming_with_messages");
    xybrid_trace::add_metadata("model", model_file);
    xybrid_trace::add_metadata("message_count", messages.len().to_string());
    stamp_llm_span_cost_attribution(metadata);
    let model_path = Path::new(&self.base_path).join(model_file);
    let model_path_str = model_path.to_string_lossy().to_string();
    let need_load = match &self.llm_adapter_cache {
        Some((cached_path, _)) if cached_path == &model_path_str => false,
        _ => true,
    };
    if need_load {
        let mut llm_config =
            LlmConfig::new(model_path_str.clone()).with_context_length(context_length);
        // The adapter cache is keyed only by model path and is shared with the
        // single-turn paths, which apply the template's chat template on load.
        // Apply it here as well so the loaded adapter does not depend on which
        // entry point loaded the model first.
        let declared_template = match &metadata.execution_template {
            ExecutionTemplate::Gguf { chat_template, .. } => chat_template.as_deref(),
            _ => None,
        };
        if let Some(template) = declared_template {
            let template_path = Path::new(&self.base_path).join(template);
            llm_config =
                llm_config.with_chat_template(template_path.to_string_lossy().to_string());
        }
        let mut adapter = LlmRuntimeAdapter::with_backend_hint(backend_hint)?;
        adapter.load_model_with_config(&llm_config)?;
        self.llm_adapter_cache = Some((model_path_str.clone(), adapter));
    }
    let gen_config = config.cloned().unwrap_or_default();
    let (output, backend_name, cached_prefix) =
        if let Some((_, adapter)) = &self.llm_adapter_cache {
            stamp_llm_runtime_backend(adapter);
            let backend = adapter.backend();
            let out = backend.generate_streaming(messages, &gen_config, on_token)?;
            let name = backend.name().to_string();
            let cached = backend.last_cached_prefix_len();
            (out, name, cached)
        } else {
            return Err(AdapterError::RuntimeError(
                "LLM adapter cache unexpectedly empty".to_string(),
            ));
        };
    let mut response_metadata = std::collections::HashMap::new();
    response_metadata.insert(
        "tokens_generated".to_string(),
        output.tokens_generated.to_string(),
    );
    response_metadata.insert(
        "generation_time_ms".to_string(),
        output.generation_time_ms.to_string(),
    );
    response_metadata.insert(
        "tokens_per_second".to_string(),
        format!("{:.2}", output.tokens_per_second),
    );
    response_metadata.insert("finish_reason".to_string(), output.finish_reason.clone());
    insert_llm_streaming_metrics(&mut response_metadata, &output);
    mirror_llm_metrics_to_span(&output, &backend_name, cached_prefix);
    Ok(Envelope {
        kind: EnvelopeKind::Text(output.text),
        metadata: response_metadata,
    })
}
/// Runs single-turn (non-streaming) LLM generation for a GGUF model.
///
/// The prompt is `input`'s text. An optional `system_prompt` entry in
/// `input.metadata` is sent as a leading system message. The returned
/// envelope holds the generated text plus generation statistics.
///
/// The LLM adapter is cached per full model path. Reuse, a path change and an
/// empty cache are each logged. On a cache miss a new adapter is created for
/// `backend_hint` and the model is loaded with `context_length`. When
/// `chat_template` is given, it is resolved against the executor's base path
/// and applied as well.
///
/// Without an explicit `config`, generation starts from
/// `GenerationConfig::default()`. `max_tokens` and `temperature` entries in
/// `input.metadata` override those defaults when they parse.
///
/// # Errors
/// Returns `AdapterError::InvalidInput` if `input` is not text. This check
/// runs before the adapter cache is consulted, so a rejected request never
/// evicts the cached adapter or triggers a model load. Adapter creation, model
/// loading and generation errors are propagated.
#[cfg(any(feature = "llm-mistral", feature = "llm-llamacpp"))]
fn execute_llm(
    &mut self,
    metadata: &ModelMetadata,
    model_file: &str,
    chat_template: Option<&str>,
    context_length: usize,
    input: &Envelope,
    backend_hint: Option<&str>,
    config: Option<&GenerationConfig>,
) -> ExecutorResult<Envelope> {
    info!(
        target: "xybrid_core",
        "Executing LLM inference: {} (backend: {:?})",
        model_file,
        backend_hint.unwrap_or("default")
    );
    let _llm_span = xybrid_trace::SpanGuard::new("llm_inference");
    xybrid_trace::add_metadata("model", model_file);
    stamp_llm_span_cost_attribution(metadata);

    // Validate the cheap-to-check input before the expensive model load, so an
    // invalid request cannot replace the cached adapter only to be rejected.
    let prompt = match &input.kind {
        EnvelopeKind::Text(text) => text.clone(),
        _ => {
            return Err(AdapterError::InvalidInput(
                "LLM requires text input".to_string(),
            ))
        }
    };

    let model_path = Path::new(&self.base_path).join(model_file);
    let model_path_str = model_path.to_string_lossy().to_string();
    let need_load = match &self.llm_adapter_cache {
        Some((cached_path, _)) if cached_path == &model_path_str => {
            info!(target: "xybrid_core", "Reusing cached LLM adapter for: {}", model_path_str);
            false
        }
        Some((cached_path, _)) => {
            info!(
                target: "xybrid_core",
                "Model path changed ({} -> {}), loading new model",
                cached_path,
                model_path_str
            );
            true
        }
        None => {
            info!(target: "xybrid_core", "No cached adapter, loading model: {}", model_path_str);
            true
        }
    };
    if need_load {
        let mut llm_config =
            LlmConfig::new(model_path_str.clone()).with_context_length(context_length);
        if let Some(template) = chat_template {
            let template_path = Path::new(&self.base_path).join(template);
            llm_config =
                llm_config.with_chat_template(template_path.to_string_lossy().to_string());
        }
        let mut adapter = LlmRuntimeAdapter::with_backend_hint(backend_hint)?;
        adapter.load_model_with_config(&llm_config)?;
        self.llm_adapter_cache = Some((model_path_str.clone(), adapter));
    }

    // An explicit config wins; otherwise envelope metadata may override defaults.
    let gen_config = if let Some(cfg) = config {
        cfg.clone()
    } else {
        let mut cfg = GenerationConfig::default();
        if let Some(max_tokens) = input
            .metadata
            .get("max_tokens")
            .and_then(|s| s.parse().ok())
        {
            cfg.max_tokens = max_tokens;
        }
        if let Some(temperature) = input
            .metadata
            .get("temperature")
            .and_then(|s| s.parse().ok())
        {
            cfg.temperature = temperature;
        }
        cfg
    };

    let system_prompt = input.metadata.get("system_prompt").map(|s| s.as_str());
    let mut messages = Vec::new();
    if let Some(sys) = system_prompt {
        messages.push(ChatMessage::system(sys));
    }
    messages.push(ChatMessage::user(&prompt));

    let (output, backend_name, cached_prefix) =
        if let Some((_, adapter)) = &self.llm_adapter_cache {
            stamp_llm_runtime_backend(adapter);
            let backend = adapter.backend();
            let out = backend.generate(&messages, &gen_config)?;
            let name = backend.name().to_string();
            let cached = backend.last_cached_prefix_len();
            (out, name, cached)
        } else {
            return Err(AdapterError::RuntimeError(
                "LLM adapter cache unexpectedly empty".to_string(),
            ));
        };
    info!(
        target: "xybrid_core",
        "LLM inference complete"
    );
    let mut response_metadata = std::collections::HashMap::new();
    response_metadata.insert(
        "tokens_generated".to_string(),
        output.tokens_generated.to_string(),
    );
    response_metadata.insert(
        "generation_time_ms".to_string(),
        output.generation_time_ms.to_string(),
    );
    response_metadata.insert(
        "tokens_per_second".to_string(),
        format!("{:.2}", output.tokens_per_second),
    );
    response_metadata.insert("finish_reason".to_string(), output.finish_reason.clone());
    insert_llm_streaming_metrics(&mut response_metadata, &output);
    mirror_llm_metrics_to_span(&output, &backend_name, cached_prefix);
    Ok(Envelope {
        kind: EnvelopeKind::Text(output.text),
        metadata: response_metadata,
    })
}
/// Applies the model's preprocessing steps to `input`, in order.
///
/// With no configured steps, `input` is converted straight into
/// [`PreprocessedData`]. Otherwise each step runs inside a
/// `preprocessing:<name>` trace span. The span is tagged `span_kind = "io"`
/// for the audio-decode step and `"cpu"` for every other step. The first
/// failing step aborts the chain.
fn run_preprocessing(
    &mut self,
    metadata: &ModelMetadata,
    input: &Envelope,
) -> ExecutorResult<PreprocessedData> {
    let steps = &metadata.preprocessing;
    if steps.is_empty() {
        debug!(target: "xybrid_core", "No preprocessing steps configured");
        return PreprocessedData::from_envelope(input);
    }
    let step_count = steps.len();
    info!(
        target: "xybrid_core",
        "Running {} preprocessing step(s)",
        step_count
    );
    let _preprocess_span = xybrid_trace::SpanGuard::new("preprocessing");
    xybrid_trace::add_metadata("steps", step_count.to_string());
    let base_path = &self.base_path;
    let seed = PreprocessedData::from_envelope(input)?;
    let processed = steps.iter().try_fold(seed, |current, step| {
        let step_name = step.step_name();
        debug!(target: "xybrid_core", "Applying preprocessing: {}", step_name);
        let _step_span = xybrid_trace::SpanGuard::new(format!("preprocessing:{}", step_name));
        let span_kind = if step_name.eq_ignore_ascii_case("audiodecode") {
            "io"
        } else {
            "cpu"
        };
        xybrid_trace::add_metadata("span_kind", span_kind);
        preprocessing::apply_preprocessing_step(step, current, input, base_path)
    })?;
    debug!(target: "xybrid_core", "Preprocessing complete");
    Ok(processed)
}
/// Runs a multi-stage model graph and returns the outputs of the stage that
/// ends it.
///
/// - `SingleShot` stages run on the `"onnx"` runtime. Each stage's output map
///   is recorded under the stage name so later stages can read it, and one of
///   its tensors becomes the data handed to the next stage.
/// - `Autoregressive` and `WhisperDecoder` stages load their ONNX session and
///   return the generated token ids immediately, which ends the pipeline.
/// - `IterativeRefinement` is not implemented and returns an error.
///
/// If the loop finishes without an early return, the outputs of the last
/// `SingleShot` stage executed are returned as a tensor map.
///
/// # Errors
/// Fails when the ONNX runtime is not registered, when a stage fails, when an
/// unimplemented mode is hit, or when no stage produced outputs (e.g. an
/// empty `stages` slice).
fn execute_pipeline(
    &mut self,
    stages: &[PipelineStage],
    config: &HashMap<String, serde_json::Value>,
    initial_input: PreprocessedData,
    _metadata: &ModelMetadata,
) -> ExecutorResult<RawOutputs> {
    let mut stage_outputs: HashMap<String, HashMap<String, ArrayD<f32>>> = HashMap::new();
    // Outputs of the most recently executed stage. Tracked explicitly: reading
    // `stage_outputs.iter().last()` would return an arbitrary stage, because
    // `HashMap` iteration order is unspecified.
    let mut last_outputs: Option<HashMap<String, ArrayD<f32>>> = None;
    let mut current_data = initial_input;
    for (idx, stage) in stages.iter().enumerate() {
        debug!(
            target: "xybrid_core",
            "Executing pipeline stage {}/{}: {} ({:?})",
            idx + 1,
            stages.len(),
            stage.name,
            stage.execution_mode
        );
        match &stage.execution_mode {
            ExecutionMode::SingleShot => {
                let runtime = self.runtimes.get_mut("onnx").ok_or_else(|| {
                    AdapterError::RuntimeError("ONNX runtime not configured".to_string())
                })?;
                let outputs = execute_single_shot_stage(
                    stage,
                    &current_data,
                    &stage_outputs,
                    runtime.as_mut(),
                    &self.base_path,
                )?;
                stage_outputs.insert(stage.name.clone(), outputs.clone());
                // NOTE(review): `values().next()` on a HashMap picks an arbitrary tensor
                // when a stage has several outputs; confirm that stages feeding a
                // successor expose exactly one output.
                if let Some(first_output) = outputs.values().next() {
                    current_data = PreprocessedData::Tensor(first_output.clone());
                }
                last_outputs = Some(outputs);
            }
            ExecutionMode::Autoregressive {
                max_tokens,
                start_token_id,
                end_token_id,
                repetition_penalty,
            } => {
                let session = self.get_or_load_session(&stage.model_file)?;
                let token_ids = execute_autoregressive_stage(
                    stage,
                    &stage_outputs,
                    config,
                    *max_tokens,
                    *start_token_id,
                    *end_token_id,
                    *repetition_penalty,
                    session,
                )?;
                return Ok(RawOutputs::TokenIds(token_ids));
            }
            ExecutionMode::IterativeRefinement { num_steps, .. } => {
                return Err(AdapterError::InvalidInput(format!(
                    "IterativeRefinement not yet implemented (needs {} steps)",
                    num_steps
                )));
            }
            ExecutionMode::WhisperDecoder {
                max_tokens,
                start_token_id,
                end_token_id,
                language_token_id,
                task_token_id,
                no_timestamps_token_id,
                suppress_tokens,
                repetition_penalty,
            } => {
                let session = self.get_or_load_session(&stage.model_file)?;
                let token_ids = execute_whisper_decoder_stage(
                    stage,
                    &stage_outputs,
                    config,
                    *max_tokens,
                    *start_token_id,
                    *end_token_id,
                    *language_token_id,
                    *task_token_id,
                    *no_timestamps_token_id,
                    suppress_tokens,
                    *repetition_penalty,
                    session,
                )?;
                return Ok(RawOutputs::TokenIds(token_ids));
            }
        }
    }
    last_outputs
        .map(RawOutputs::TensorMap)
        .ok_or_else(|| AdapterError::InvalidInput("Pipeline produced no outputs".to_string()))
}
/// Applies the model's postprocessing steps, in order, to raw model outputs
/// and converts the result into an [`Envelope`].
///
/// With no configured steps, the raw outputs are converted directly.
/// Otherwise each step runs inside a `postprocessing:<name>` trace span. The
/// span is tagged `span_kind = "io"` for the TTS audio-encode step and
/// `"cpu"` for every other step. The first failing step aborts the chain.
fn run_postprocessing(
    &mut self,
    metadata: &ModelMetadata,
    outputs: RawOutputs,
) -> ExecutorResult<Envelope> {
    let steps = &metadata.postprocessing;
    if steps.is_empty() {
        debug!(target: "xybrid_core", "No postprocessing steps configured");
        return outputs.to_envelope();
    }
    let step_count = steps.len();
    info!(
        target: "xybrid_core",
        "Running {} postprocessing step(s)",
        step_count
    );
    let _postprocess_span = xybrid_trace::SpanGuard::new("postprocessing");
    xybrid_trace::add_metadata("steps", step_count.to_string());
    let base_path = &self.base_path;
    let processed = steps.iter().try_fold(outputs, |current, step| {
        let step_name = step.step_name();
        debug!(target: "xybrid_core", "Applying postprocessing: {}", step_name);
        let _step_span = xybrid_trace::SpanGuard::new(format!("postprocessing:{}", step_name));
        let span_kind = if step_name.eq_ignore_ascii_case("ttsaudioencode") {
            "io"
        } else {
            "cpu"
        };
        xybrid_trace::add_metadata("span_kind", span_kind);
        postprocessing::apply_postprocessing_step(step, current, base_path)
    })?;
    debug!(target: "xybrid_core", "Postprocessing complete");
    processed.to_envelope()
}
/// Loads `model_file` (resolved against `base_path`) into the registered `"onnx"` runtime and
/// returns the session that runtime holds for it.
///
/// `ModelRuntime::load` is called on every invocation; whether a repeated load is cheap
/// depends on the runtime implementation (not visible here).
///
/// # Errors
/// Returns `AdapterError::RuntimeError` when no `"onnx"` runtime is registered, when loading
/// fails, or when the registered runtime is not an `OnnxRuntime`; errors from
/// `OnnxRuntime::get_session` are propagated unchanged.
fn get_or_load_session(&mut self, model_file: &str) -> ExecutorResult<&ONNXSession> {
    let model_full_path = Path::new(&self.base_path).join(model_file);
    let runtime = self.runtimes.get_mut("onnx").ok_or_else(|| {
        AdapterError::RuntimeError("ONNX runtime not configured".to_string())
    })?;
    runtime
        .load(&model_full_path)
        .map_err(|e| AdapterError::RuntimeError(format!("Failed to load session: {}", e)))?;
    // Reuse the entry found above (downgraded to a shared borrow) instead of a second map
    // lookup followed by `unwrap()`.
    let onnx_rt = runtime
        .as_any()
        .downcast_ref::<OnnxRuntime>()
        .ok_or_else(|| {
            AdapterError::RuntimeError("Runtime 'onnx' is not OnnxRuntime".to_string())
        })?;
    let path_str = model_full_path.to_string_lossy();
    onnx_rt.get_session(&path_str)
}
/// Resolves `file` relative to the executor's base path.
///
/// Returns `file` unchanged when no base path is set; otherwise joins it onto `base_path`
/// and converts the result lossily (invalid UTF-8 becomes U+FFFD).
pub fn resolve_file_path(&self, file: &str) -> String {
    if self.base_path.is_empty() {
        return file.to_string();
    }
    let joined = Path::new(&self.base_path).join(file);
    joined.to_string_lossy().into_owned()
}
/// Conjunction-like words used as split points for over-long TTS sentences.
///
/// `find_nearest_break_word` searches for each entry surrounded by spaces, and
/// `migrate_trailing_break_words` moves an entry that ends a chunk onto the next chunk.
/// Order matters: `find_nearest_break_word` keeps the first match found at the smallest
/// distance from the centre, so earlier entries win ties.
const BREAK_WORDS: &'static [&'static str] = &[
"and", "or", "but", "because", "if", "however", "which", "when", "where", "while",
"although", "since", "unless", "after", "before", "that",
];
fn chunk_text_for_tts(text: &str, max_chars: usize) -> Vec<String> {
if text.len() <= max_chars {
return vec![text.to_string()];
}
let mut chunks = Vec::new();
let mut current_chunk = String::new();
let sentences: Vec<&str> = text.split_inclusive(['.', '!', '?']).collect();
for sentence in sentences {
let sentence = sentence.trim();
if sentence.is_empty() {
continue;
}
if sentence.len() > max_chars {
if !current_chunk.is_empty() {
chunks.push(current_chunk.trim().to_string());
current_chunk = String::new();
}
let mut sub_chunks = Vec::new();
Self::center_break_split(sentence, max_chars, 0, &mut sub_chunks);
if let Some(last) = sub_chunks.pop() {
for sc in sub_chunks {
if !sc.is_empty() {
chunks.push(sc);
}
}
current_chunk = last;
}
} else if current_chunk.len() + sentence.len() + 1 > max_chars {
if !current_chunk.is_empty() {
chunks.push(current_chunk.trim().to_string());
}
current_chunk = sentence.to_string();
} else {
if !current_chunk.is_empty() {
current_chunk.push(' ');
}
current_chunk.push_str(sentence);
}
}
if !current_chunk.is_empty() {
chunks.push(current_chunk.trim().to_string());
}
Self::migrate_trailing_break_words(&mut chunks);
chunks
}
fn center_break_split(text: &str, max_chars: usize, depth: usize, out: &mut Vec<String>) {
const MAX_DEPTH: usize = 3;
let trimmed = text.trim();
if trimmed.is_empty() {
return;
}
if trimmed.len() <= max_chars || depth >= MAX_DEPTH {
out.push(trimmed.to_string());
return;
}
let center = trimmed.len() / 2;
if let Some(pos) = Self::find_nearest(trimmed, center, |i, _| {
trimmed.as_bytes().get(i) == Some(&b',')
}) {
let left = trimmed[..=pos].trim();
let right = trimmed[pos + 1..].trim();
Self::center_break_split(left, max_chars, depth + 1, out);
Self::center_break_split(right, max_chars, depth + 1, out);
return;
}
if let Some((word_start, word_len)) = Self::find_nearest_break_word(trimmed, center) {
let left = trimmed[..word_start].trim();
let right = trimmed[word_start + word_len..].trim();
let break_word = &trimmed[word_start..word_start + word_len];
Self::center_break_split(left, max_chars, depth + 1, out);
let right_with_word = format!("{} {}", break_word, right);
Self::center_break_split(right_with_word.trim(), max_chars, depth + 1, out);
return;
}
if let Some(pos) = Self::find_nearest(trimmed, center, |i, _| {
trimmed
.as_bytes()
.get(i)
.is_some_and(|b| b.is_ascii_whitespace())
}) {
let left = trimmed[..pos].trim();
let right = trimmed[pos + 1..].trim();
Self::center_break_split(left, max_chars, depth + 1, out);
Self::center_break_split(right, max_chars, depth + 1, out);
return;
}
out.push(trimmed.to_string());
}
/// Returns the byte index closest to `center` at which `pred(index, char_at_index)` holds.
///
/// Candidates are probed outward from `center`: `center`, then `center + 1`,
/// `center - 1`, `center + 2`, `center - 2`, … (the right side is tried before the left at
/// equal distance). The search covers at most `text.len()` offsets. Indices that are out of
/// range or do not fall on a UTF-8 character boundary are skipped, so the predicate only ever
/// sees real characters. Returns `None` when no candidate matches.
fn find_nearest<F>(text: &str, center: usize, pred: F) -> Option<usize>
where
    F: Fn(usize, char) -> bool,
{
    let len = text.len();
    // `str::get` yields `None` for an index inside a multi-byte character (or past the end),
    // where plain `text[idx..]` would panic; callers pass byte midpoints that can land there.
    let probe = |idx: usize| {
        text.get(idx..)
            .and_then(|rest| rest.chars().next())
            .filter(|&ch| pred(idx, ch))
            .map(|_| idx)
    };
    for offset in 0..len {
        let right = center.saturating_add(offset);
        if right < len {
            if let Some(found) = probe(right) {
                return Some(found);
            }
        }
        if offset > 0 && offset <= center {
            if let Some(found) = probe(center - offset) {
                return Some(found);
            }
        }
    }
    None
}
/// Finds the occurrence of a `BREAK_WORDS` entry (surrounded by single spaces) whose start is
/// closest to the byte offset `center`.
///
/// Matching is ASCII case-insensitive. Returns `(start, len)` as byte offsets into `text`
/// covering just the word (without the surrounding spaces), or `None` if no break word occurs
/// in the middle of the text. On equal distance the earlier-found match wins.
fn find_nearest_break_word(text: &str, center: usize) -> Option<(usize, usize)> {
    // ASCII-only case folding keeps byte offsets in `haystack` identical to those in `text`.
    // Full Unicode `to_lowercase` can change byte lengths (e.g. 'İ' -> "i̇"), which would make
    // the returned range mis-slice, or panic on, the original text. The break words are ASCII.
    let haystack = text.to_ascii_lowercase();
    let mut best: Option<(usize, usize, usize)> = None;
    for word in Self::BREAK_WORDS {
        let pattern = format!(" {} ", word);
        let mut search_start = 0;
        while let Some(pos) = haystack[search_start..].find(&pattern) {
            // +1 skips the leading space so `abs_pos` is the first byte of the word.
            let abs_pos = search_start + pos + 1;
            let dist = abs_pos.abs_diff(center);
            if best.map_or(true, |(_, _, best_dist)| dist < best_dist) {
                best = Some((abs_pos, word.len(), dist));
            }
            search_start = abs_pos;
        }
    }
    best.map(|(start, len, _)| (start, len))
}
fn migrate_trailing_break_words(chunks: &mut [String]) {
let mut i = 0;
while i + 1 < chunks.len() {
let ends_with_break = Self::BREAK_WORDS.iter().any(|w| {
let chunk = &chunks[i];
let lower = chunk.to_lowercase();
lower.ends_with(&format!(" {}", w)) || lower == *w
});
if ends_with_break {
let chunk = chunks[i].clone();
if let Some(last_space) = chunk.rfind(' ') {
let word = &chunk[last_space + 1..];
let lower_word = word.to_lowercase();
if Self::BREAK_WORDS.contains(&lower_word.as_str()) {
chunks[i] = chunk[..last_space].trim().to_string();
chunks[i + 1] = format!("{} {}", word, chunks[i + 1]);
}
}
}
i += 1;
}
}
fn execute_tts_chunked(
&mut self,
metadata: &ModelMetadata,
input: &Envelope,
model_path: &Path,
) -> ExecutorResult<Envelope> {
use crate::ir::EnvelopeKind;
const DEFAULT_MAX_TTS_CHARS: usize = 350;
let max_tts_chars = metadata.max_chunk_chars.unwrap_or(DEFAULT_MAX_TTS_CHARS);
let text = match &input.kind {
EnvelopeKind::Text(t) => t.clone(),
_ => {
return Err(AdapterError::InvalidInput(
"TTS requires text input".to_string(),
))
}
};
debug!(
target: "xybrid_core",
"TTS Chunked: Input text length: {} chars (max_chunk_chars={})",
text.len(),
max_tts_chars
);
if text.len() <= max_tts_chars {
debug!(target: "xybrid_core", "TTS: Text is short enough, using single execution");
return self.execute_tts_single(metadata, input, model_path);
}
debug!(
target: "xybrid_core",
"TTS: Text too long ({} chars), splitting into chunks",
text.len()
);
let chunks = Self::chunk_text_for_tts(&text, max_tts_chars);
debug!(target: "xybrid_core", "TTS: Split into {} chunks", chunks.len());
const CROSSFADE_SAMPLES: usize = 480;
let mut audio_chunks: Vec<Vec<f32>> = Vec::new();
let session = OnnxSessionFactory::create_session(
model_path,
ExecutionProviderKind::Cpu,
SessionOptions::default(),
)?;
let speed = extract_tts_speed(input);
for (i, chunk) in chunks.iter().enumerate() {
debug!(target: "xybrid_core", "TTS: Processing chunk {}/{}: {} chars", i + 1, chunks.len(), chunk.len());
let chunk_input = Envelope {
kind: EnvelopeKind::Text(chunk.clone()),
metadata: input.metadata.clone(),
};
let preprocessed = self.run_preprocessing(metadata, &chunk_input)?;
let phoneme_ids = preprocessed
.as_phoneme_ids()
.ok_or_else(|| AdapterError::InvalidInput("Expected phoneme IDs".to_string()))?;
debug!(target: "xybrid_core", "TTS: Chunk {} has {} phoneme IDs", i + 1, phoneme_ids.len());
let voice_loader = TtsVoiceLoader::new(&self.base_path);
let voice_embedding = voice_loader.load(metadata, input)?;
let raw_outputs = execute_tts_inference(&session, phoneme_ids, voice_embedding, speed)?;
if let Some(audio_tensor) = raw_outputs.values().next() {
let mut chunk_audio: Vec<f32> = audio_tensor.iter().cloned().collect();
let trim_count = metadata.trim_trailing_samples.unwrap_or(0);
if trim_count > 0 && chunk_audio.len() > trim_count {
chunk_audio.truncate(chunk_audio.len() - trim_count);
}
audio_chunks.push(chunk_audio);
}
}
let all_audio = crossfade_audio_chunks(&audio_chunks, CROSSFADE_SAMPLES);
debug!(target: "xybrid_core", "TTS: Total audio samples: {}", all_audio.len());
let output_names = session.output_names();
let output_name = output_names.first().map(|s| s.as_str()).unwrap_or("audio");
let mut combined_outputs: HashMap<String, ArrayD<f32>> = HashMap::new();
let audio_array = ndarray::Array1::from_vec(all_audio).into_dyn();
combined_outputs.insert(output_name.to_string(), audio_array);
self.run_postprocessing(metadata, RawOutputs::TensorMap(combined_outputs))
}
/// Synthesizes speech for `input` in a single inference pass.
///
/// Preprocesses the envelope into phoneme IDs, loads the voice embedding, and runs the model
/// in a fresh CPU ONNX session. When `metadata.trim_trailing_samples` is set, that many
/// samples are removed from the end of each output tensor. All outputs are then handed to
/// postprocessing.
///
/// # Errors
/// Returns `AdapterError::InvalidInput` when preprocessing does not yield phoneme IDs.
/// Voice-loading, session, inference and postprocessing errors are propagated.
fn execute_tts_single(
    &mut self,
    metadata: &ModelMetadata,
    input: &Envelope,
    model_path: &Path,
) -> ExecutorResult<Envelope> {
    let preprocessed = self.run_preprocessing(metadata, input)?;
    let phoneme_ids = preprocessed
        .as_phoneme_ids()
        .ok_or_else(|| AdapterError::InvalidInput("Expected phoneme IDs".to_string()))?;
    debug!(
        target: "xybrid_core",
        "TTS Single: Input text length: {} chars, first 100: {:?}",
        match &input.kind {
            crate::ir::EnvelopeKind::Text(t) => t.len(),
            _ => 0,
        },
        match &input.kind {
            crate::ir::EnvelopeKind::Text(t) => t.chars().take(100).collect::<String>(),
            _ => "(not text)".to_string(),
        }
    );
    debug!(
        target: "xybrid_core",
        "TTS: Phoneme IDs count: {}, first 20: {:?}",
        phoneme_ids.len(),
        &phoneme_ids[..phoneme_ids.len().min(20)]
    );
    let voice_loader = TtsVoiceLoader::new(&self.base_path);
    let voice_embedding = voice_loader.load(metadata, input)?;
    let session = OnnxSessionFactory::create_session(
        model_path,
        ExecutionProviderKind::Cpu,
        SessionOptions::default(),
    )?;
    let speed = extract_tts_speed(input);
    let mut raw_outputs = execute_tts_inference(&session, phoneme_ids, voice_embedding, speed)?;
    let trim_count = metadata.trim_trailing_samples.unwrap_or(0);
    if trim_count > 0 {
        // NOTE(review): every output tensor is trimmed, not only the audio one.
        for audio in raw_outputs.values_mut() {
            // Slice along the last axis. For 1-D output this is a plain tail trim. It also
            // handles shapes such as `[1, N]`, which a single-axis `s![..n]` with
            // `slice_collapse` rejects (panics) on a multi-dimensional `IxDyn` array.
            // Assumes samples run along the last axis — TODO confirm for multi-axis models.
            let Some(time_axis) = audio.ndim().checked_sub(1).map(ndarray::Axis) else {
                continue;
            };
            let samples = audio.len_of(time_axis);
            if samples > trim_count {
                audio.slice_axis_inplace(time_axis, ndarray::Slice::from(..samples - trim_count));
            }
        }
    }
    self.run_postprocessing(metadata, RawOutputs::TensorMap(raw_outputs))
}
/// Returns `true` when any of the model's preprocessing steps is a
/// `PreprocessingStep::Phonemize` step, which is how TTS models are recognised here.
fn is_tts_model(metadata: &ModelMetadata) -> bool {
    use super::template::PreprocessingStep;
    for step in &metadata.preprocessing {
        if let PreprocessingStep::Phonemize { .. } = step {
            return true;
        }
    }
    false
}
}
/// Reads the optional TTS `speed` override from the envelope metadata.
///
/// Returns `1.0` when the key is absent, cannot be parsed as `f32` (surrounding whitespace is
/// ignored), or parses to NaN. Values outside `[0.5, 2.0]`, including infinities, are clamped
/// into that range and a warning is logged.
pub(crate) fn extract_tts_speed(envelope: &Envelope) -> f32 {
    const DEFAULT_SPEED: f32 = 1.0;
    const MIN_SPEED: f32 = 0.5;
    const MAX_SPEED: f32 = 2.0;
    let speed = envelope
        .metadata
        .get("speed")
        .and_then(|s| s.trim().parse::<f32>().ok())
        // `"NaN".parse::<f32>()` succeeds and `f32::clamp` propagates NaN, so without this
        // filter a NaN speed would slip past the range guard below.
        .filter(|v| !v.is_nan())
        .unwrap_or(DEFAULT_SPEED);
    if !(MIN_SPEED..=MAX_SPEED).contains(&speed) {
        warn!(
            "TTS speed {:.2} is outside valid range [0.5, 2.0], clamping",
            speed
        );
        return speed.clamp(MIN_SPEED, MAX_SPEED);
    }
    speed
}
/// Concatenates audio chunks, blending each boundary with a linear crossfade.
///
/// At each boundary the last `crossfade_len` samples of the output so far are mixed with the
/// first `crossfade_len` samples of the next chunk. The weights are
/// `t = (k + 1) / (crossfade_len + 1)` for the incoming sample and `1 - t` for the outgoing
/// one, so each blended boundary shortens the total by `crossfade_len` samples. If either
/// side is shorter than `2 * crossfade_len`, the chunk is appended without blending.
///
/// Returns an empty vector for no chunks and a copy of the chunk when there is only one.
fn crossfade_audio_chunks(chunks: &[Vec<f32>], crossfade_len: usize) -> Vec<f32> {
    let Some((first, rest)) = chunks.split_first() else {
        return Vec::new();
    };
    let min_len = 2 * crossfade_len;
    let denom = (crossfade_len + 1) as f32;
    let mut mixed = first.clone();
    for next in rest {
        let can_blend = mixed.len() >= min_len && next.len() >= min_len;
        if !can_blend {
            mixed.extend_from_slice(next);
            continue;
        }
        let tail_start = mixed.len() - crossfade_len;
        let (head, body) = next.split_at(crossfade_len);
        for (k, (out, &incoming)) in mixed[tail_start..].iter_mut().zip(head).enumerate() {
            let t = (k + 1) as f32 / denom;
            *out = *out * (1.0 - t) + incoming * t;
        }
        mixed.extend_from_slice(body);
    }
    mixed
}
impl Default for TemplateExecutor {
fn default() -> Self {
Self::new("")
}
}
/// Copies the optional streaming metrics of `output` into `response_metadata`.
///
/// Only metrics that are present are inserted. `mean_itl_ms`, `decode_tps` and `prefill_tps`
/// are formatted with four decimals; `ttft_ms`, `p95_itl_ms` and `emitted_chunks` use their
/// default `Display` form.
#[cfg(any(feature = "llm-mistral", feature = "llm-llamacpp"))]
fn insert_llm_streaming_metrics(
    response_metadata: &mut HashMap<String, String>,
    output: &crate::runtime_adapter::llm::GenerationOutput,
) {
    let entries = [
        ("ttft_ms", output.ttft_ms.map(|v| v.to_string())),
        ("mean_itl_ms", output.mean_itl_ms.map(|v| format!("{:.4}", v))),
        ("p95_itl_ms", output.p95_itl_ms.map(|v| v.to_string())),
        ("emitted_chunks", output.emitted_chunks.map(|v| v.to_string())),
        ("decode_tps", output.decode_tps.map(|v| format!("{:.4}", v))),
        ("prefill_tps", output.prefill_tps.map(|v| format!("{:.4}", v))),
    ];
    for (key, value) in entries {
        if let Some(value) = value {
            response_metadata.insert(key.to_string(), value);
        }
    }
}
/// Records LLM generation metrics as metadata on the current trace span.
///
/// Always records token count, generation time, tokens/second (two decimals), finish reason
/// and the local execution provider derived from `backend_name`. `prompt_cached_tokens` is
/// added only when `cached_prefix_tokens` is a positive count. Each optional streaming metric
/// is added only when present, using the same formatting as the response metadata (four
/// decimals for `mean_itl_ms`, `decode_tps` and `prefill_tps`).
#[cfg(any(feature = "llm-mistral", feature = "llm-llamacpp"))]
fn mirror_llm_metrics_to_span(
    output: &crate::runtime_adapter::llm::GenerationOutput,
    backend_name: &str,
    cached_prefix_tokens: Option<usize>,
) {
    xybrid_trace::add_metadata("tokens_generated", output.tokens_generated.to_string());
    xybrid_trace::add_metadata("generation_time_ms", output.generation_time_ms.to_string());
    xybrid_trace::add_metadata(
        "tokens_per_second",
        format!("{:.2}", output.tokens_per_second),
    );
    xybrid_trace::add_metadata("finish_reason", &output.finish_reason);
    let provider = crate::runtime_adapter::llm::local_execution_provider(backend_name);
    xybrid_trace::add_metadata("execution_provider", provider);
    if let Some(cached) = cached_prefix_tokens.filter(|&n| n > 0) {
        xybrid_trace::add_metadata("prompt_cached_tokens", cached.to_string());
    }
    let streaming = [
        ("ttft_ms", output.ttft_ms.map(|v| v.to_string())),
        ("mean_itl_ms", output.mean_itl_ms.map(|v| format!("{:.4}", v))),
        ("p95_itl_ms", output.p95_itl_ms.map(|v| v.to_string())),
        ("emitted_chunks", output.emitted_chunks.map(|v| v.to_string())),
        ("decode_tps", output.decode_tps.map(|v| format!("{:.4}", v))),
        ("prefill_tps", output.prefill_tps.map(|v| format!("{:.4}", v))),
    ];
    for (key, value) in streaming {
        if let Some(value) = value {
            xybrid_trace::add_metadata(key, value);
        }
    }
}
#[cfg(test)]
mod tests {
use super::super::template::PreprocessingStep;
use super::*;
use crate::ir::EnvelopeKind;
#[test]
fn test_executor_creation() {
let executor = TemplateExecutor::default();
assert_eq!(executor.base_path, "");
}
#[test]
fn test_executor_with_base_path() {
let executor = TemplateExecutor::with_base_path("/path/to/models");
assert_eq!(executor.base_path, "/path/to/models");
}
#[test]
fn test_resolve_file_path() {
let executor = TemplateExecutor::with_base_path("/models");
let resolved = executor.resolve_file_path("encoder.onnx");
assert!(resolved.contains("encoder.onnx"));
}
#[test]
fn test_resolve_file_path_empty_base() {
let executor = TemplateExecutor::with_base_path("");
let resolved = executor.resolve_file_path("encoder.onnx");
assert_eq!(resolved, "encoder.onnx");
}
#[test]
fn test_default_runtimes_contains_onnx() {
let runtimes = TemplateExecutor::default_runtimes();
assert!(runtimes.contains_key("onnx"));
}
#[test]
fn test_with_runtimes_custom_injection() {
let runtimes: HashMap<String, Box<dyn ModelRuntime>> = HashMap::new();
let executor = TemplateExecutor::with_runtimes("/test", runtimes);
assert_eq!(executor.base_path, "/test");
assert!(executor.list_runtimes().is_empty());
}
#[test]
fn test_register_runtime() {
let mut executor = TemplateExecutor::with_runtimes("/test", HashMap::new());
assert!(executor.list_runtimes().is_empty());
executor.register_runtime("onnx", Box::new(OnnxRuntime::new()));
assert!(executor.list_runtimes().contains(&"onnx"));
assert!(executor.get_runtime("onnx").is_some());
}
#[test]
fn test_list_runtimes() {
let executor = TemplateExecutor::new("/test");
let runtimes = executor.list_runtimes();
assert!(runtimes.contains(&"onnx"));
}
#[test]
fn test_get_runtime_not_found() {
let executor = TemplateExecutor::new("/test");
assert!(executor.get_runtime("nonexistent").is_none());
}
#[test]
fn test_chunk_text_short_input_unchanged() {
let text = "Hello world, this is a short sentence that is well under the limit.";
let chunks = TemplateExecutor::chunk_text_for_tts(text, 350);
assert_eq!(chunks, vec![text]);
}
#[test]
fn test_chunk_text_exactly_at_limit() {
let text = "A".repeat(350);
let chunks = TemplateExecutor::chunk_text_for_tts(&text, 350);
assert_eq!(chunks.len(), 1);
assert_eq!(chunks[0].len(), 350);
}
#[test]
fn test_chunk_text_splits_at_sentence_boundaries() {
let text = "First sentence. Second sentence. Third sentence.";
let chunks = TemplateExecutor::chunk_text_for_tts(text, 20);
assert_eq!(chunks.len(), 3);
assert_eq!(chunks[0], "First sentence.");
assert_eq!(chunks[1], "Second sentence.");
assert_eq!(chunks[2], "Third sentence.");
}
#[test]
fn test_chunk_text_combines_short_sentences() {
let text = "Hi. Hello. Hey there.";
let chunks = TemplateExecutor::chunk_text_for_tts(text, 50);
assert_eq!(chunks.len(), 1);
assert_eq!(chunks[0], "Hi. Hello. Hey there.");
}
#[test]
fn test_chunk_text_handles_exclamation_and_question() {
let text = "What? Really! Yes.";
let chunks = TemplateExecutor::chunk_text_for_tts(text, 10);
assert_eq!(chunks.len(), 3);
assert_eq!(chunks[0], "What?");
assert_eq!(chunks[1], "Really!");
assert_eq!(chunks[2], "Yes.");
}
#[test]
fn test_chunk_text_center_break_comma() {
let left = "The quick brown fox jumped over the lazy dog and ran across the wide green meadow towards the old wooden fence in the distance while the birds sang their morning songs above the tall oak trees lining the path";
let right = " and then it stopped to rest under the shade of a willow tree by the river where the water flowed gently over smooth stones and the fish swam lazily in the warm afternoon sun as clouds drifted slowly overhead";
let text = format!("{},{}", left, right);
assert!(
text.len() > 350,
"Test text must exceed 350 chars, got {}",
text.len()
);
let chunks = TemplateExecutor::chunk_text_for_tts(&text, 350);
assert!(chunks.len() >= 2, "Should split into at least 2 chunks");
assert!(
chunks[0].ends_with(','),
"First chunk should end at comma, got: '{}'",
chunks[0]
);
}
#[test]
fn test_chunk_text_center_break_word() {
let text = "The quick brown fox jumped over the lazy dog running across the wide green meadow towards the old wooden fence in the distance while birds sang their morning songs above the tall oak trees however the gentle breeze carried the sweet scent of wildflowers across the rolling hills and through the valleys where deer grazed peacefully in the golden light of the setting sun painting everything in warm hues";
assert!(
text.len() > 350,
"Test text must exceed 350 chars, got {}",
text.len()
);
assert!(!text.contains(','), "Test text should have no commas");
let chunks = TemplateExecutor::chunk_text_for_tts(text, 350);
assert!(chunks.len() >= 2, "Should split into at least 2 chunks");
let second_lower = chunks[1].to_lowercase();
let starts_with_break = TemplateExecutor::BREAK_WORDS
.iter()
.any(|w| second_lower.starts_with(w));
assert!(
starts_with_break,
"Second chunk should start with a break word after post-pass, got: '{}'",
&chunks[1][..chunks[1].len().min(40)]
);
}
#[test]
fn test_chunk_text_center_break_whitespace() {
let text = "aaaa bbbb cccc dddd eeee ffff gggg hhhh iiii jjjj kkkk llll mmmm nnnn oooo pppp qqqq rrrr ssss tttt uuuu vvvv wwww xxxx yyyy zzzz aaaa bbbb cccc dddd eeee ffff gggg hhhh iiii jjjj kkkk llll mmmm nnnn oooo pppp qqqq rrrr ssss tttt uuuu vvvv wwww xxxx yyyy zzzz aaaa bbbb cccc dddd eeee ffff gggg hhhh iiii jjjj kkkk llll mmmm nnnn oooo pppp qqqq rrrr ssss tttt uuuu";
assert!(
text.len() > 350,
"Test text must exceed 350 chars, got {}",
text.len()
);
assert!(!text.contains(','), "No commas");
let chunks = TemplateExecutor::chunk_text_for_tts(text, 350);
assert!(chunks.len() >= 2, "Should split into at least 2 chunks");
for chunk in &chunks {
assert!(!chunk.is_empty(), "Chunks should not be empty");
assert_eq!(chunk.as_str(), chunk.trim(), "Chunks should be trimmed");
}
}
#[test]
fn test_chunk_text_multi_sentence_long_first() {
let long_sentence = "The magnificent cathedral stood tall against the stormy sky its ancient stone walls bearing witness to centuries of history while gargoyles perched on every corner watched over the bustling city below where merchants sold their wares in the cobblestone market square filled with the aroma of freshly baked bread and exotic spices brought by traders from distant lands across vast oceans and treacherous mountain passes.";
let short_sentence = " A bird sang nearby.";
let text = format!("{}{}", long_sentence, short_sentence);
assert!(
long_sentence.len() > 350,
"First sentence must exceed 350 chars"
);
let chunks = TemplateExecutor::chunk_text_for_tts(&text, 350);
assert!(chunks.len() >= 2, "Should split into at least 2 chunks");
let last = chunks.last().unwrap();
assert!(
last.contains("A bird sang nearby"),
"Short sentence should be intact in last chunk, got: '{}'",
last
);
}
#[test]
fn test_chunk_text_empty_input() {
let chunks = TemplateExecutor::chunk_text_for_tts("", 350);
assert!(chunks.is_empty() || chunks == vec![""]);
}
#[test]
fn test_chunk_text_whitespace_only() {
let chunks = TemplateExecutor::chunk_text_for_tts(" ", 350);
assert!(chunks.is_empty() || chunks.iter().all(|c| c.trim().is_empty()));
}
#[test]
fn test_chunk_text_preserves_content() {
let text =
"The quick brown fox jumps over the lazy dog. Pack my box with five dozen liquor jugs.";
let chunks = TemplateExecutor::chunk_text_for_tts(text, 50);
let rejoined: String = chunks.join(" ");
assert!(rejoined.contains("quick"));
assert!(rejoined.contains("fox"));
assert!(rejoined.contains("liquor"));
}
#[test]
fn test_chunk_text_post_pass_break_word_migration() {
let text =
"The fox ran fast and the dog chased it quickly through the woods and over the hill.";
let chunks = TemplateExecutor::chunk_text_for_tts(text, 30);
for (i, chunk) in chunks.iter().enumerate() {
if i + 1 < chunks.len() {
let lower = chunk.to_lowercase();
for w in TemplateExecutor::BREAK_WORDS {
assert!(
!lower.ends_with(&format!(" {}", w)),
"Chunk {} should not end with break word '{}': '{}'",
i,
w,
chunk
);
}
}
}
}
#[test]
fn test_is_tts_model_with_phonemize_step() {
let metadata = ModelMetadata::onnx("test-tts", "1.0", "model.onnx").with_preprocessing(
PreprocessingStep::Phonemize {
tokens_file: "tokens.txt".to_string(),
backend: Default::default(),
dict_file: None,
language: None,
add_padding: true,
normalize_text: false,
silence_tokens: None,
},
);
assert!(TemplateExecutor::is_tts_model(&metadata));
}
#[test]
fn test_is_tts_model_without_phonemize() {
let metadata = ModelMetadata::onnx("test-asr", "1.0", "model.onnx").with_preprocessing(
PreprocessingStep::AudioDecode {
sample_rate: 16000,
channels: 1,
},
);
assert!(!TemplateExecutor::is_tts_model(&metadata));
}
#[test]
fn test_is_tts_model_no_preprocessing() {
let metadata = ModelMetadata::onnx("test-model", "1.0", "model.onnx");
assert!(!TemplateExecutor::is_tts_model(&metadata));
}
#[test]
fn test_is_tts_model_phonemize_among_other_steps() {
let metadata = ModelMetadata::onnx("test-tts", "1.0", "model.onnx")
.with_preprocessing(PreprocessingStep::Normalize {
mean: vec![0.0],
std: vec![1.0],
})
.with_preprocessing(PreprocessingStep::Phonemize {
tokens_file: "tokens.txt".to_string(),
backend: Default::default(),
dict_file: None,
language: None,
add_padding: true,
normalize_text: false,
silence_tokens: None,
});
assert!(TemplateExecutor::is_tts_model(&metadata));
}
#[test]
fn test_is_tts_model_with_mel_spectrogram_is_not_tts() {
let metadata = ModelMetadata::onnx("test-asr", "1.0", "model.onnx").with_preprocessing(
PreprocessingStep::MelSpectrogram {
preset: Some("whisper".to_string()),
n_mels: 80,
sample_rate: 16000,
fft_size: 400,
hop_length: 160,
mel_scale: Default::default(),
max_frames: Some(3000),
},
);
assert!(!TemplateExecutor::is_tts_model(&metadata));
}
#[test]
fn test_execute_with_context_builds_message_list() {
use crate::conversation::ConversationContext;
let mut ctx = ConversationContext::new().with_system(
Envelope::new(EnvelopeKind::Text("You are helpful.".to_string()))
.with_role(MessageRole::System),
);
ctx.push(
Envelope::new(EnvelopeKind::Text("Hello!".to_string())).with_role(MessageRole::User),
);
ctx.push(
Envelope::new(EnvelopeKind::Text("Hi there!".to_string()))
.with_role(MessageRole::Assistant),
);
let messages = ctx.context_for_llm();
assert_eq!(messages.len(), 3);
assert!(messages[0].is_system_message());
assert!(messages[1].is_user_message());
assert!(messages[2].is_assistant_message());
let input = Envelope::new(EnvelopeKind::Text("How are you?".to_string()))
.with_role(MessageRole::User);
let mut all_messages = messages.clone();
all_messages.push(&input);
assert_eq!(all_messages.len(), 4);
}
#[test]
fn test_execute_with_context_uses_chat_template_formatter() {
use super::super::chat_template::{ChatTemplateFormat, ChatTemplateFormatter};
use crate::conversation::ConversationContext;
let mut ctx = ConversationContext::new().with_system(
Envelope::new(EnvelopeKind::Text("You are helpful.".to_string()))
.with_role(MessageRole::System),
);
ctx.push(
Envelope::new(EnvelopeKind::Text("Hello!".to_string())).with_role(MessageRole::User),
);
ctx.push(
Envelope::new(EnvelopeKind::Text("Hi there!".to_string()))
.with_role(MessageRole::Assistant),
);
let input = Envelope::new(EnvelopeKind::Text("How are you?".to_string()))
.with_role(MessageRole::User);
let mut messages: Vec<&Envelope> = ctx.context_for_llm();
messages.push(&input);
let prompt = ChatTemplateFormatter::format(&messages, ChatTemplateFormat::ChatML);
assert!(prompt.contains("<|im_start|>system\nYou are helpful.<|im_end|>"));
assert!(prompt.contains("<|im_start|>user\nHello!<|im_end|>"));
assert!(prompt.contains("<|im_start|>assistant\nHi there!<|im_end|>"));
assert!(prompt.contains("<|im_start|>user\nHow are you?<|im_end|>"));
assert!(prompt.ends_with("<|im_start|>assistant\n"));
}
#[test]
fn test_execute_with_context_result_tagged_as_assistant() {
let envelope = Envelope::new(EnvelopeKind::Text("I'm doing great!".to_string()));
assert!(envelope.role().is_none());
let tagged = envelope.with_role(MessageRole::Assistant);
assert!(tagged.is_assistant_message());
assert_eq!(tagged.role(), Some(MessageRole::Assistant));
}
#[test]
fn test_execute_with_context_preserves_input_content() {
use super::super::chat_template::{ChatTemplateFormat, ChatTemplateFormatter};
use crate::conversation::ConversationContext;
let ctx = ConversationContext::new();
let input = Envelope::new(EnvelopeKind::Text("What is 2+2?".to_string()))
.with_role(MessageRole::User);
let mut messages: Vec<&Envelope> = ctx.context_for_llm();
messages.push(&input);
let prompt = ChatTemplateFormatter::format(&messages, ChatTemplateFormat::ChatML);
assert!(prompt.contains("What is 2+2?"));
}
#[test]
fn test_execute_with_context_with_empty_context() {
use super::super::chat_template::{ChatTemplateFormat, ChatTemplateFormatter};
use crate::conversation::ConversationContext;
let ctx = ConversationContext::new();
let input =
Envelope::new(EnvelopeKind::Text("Hello!".to_string())).with_role(MessageRole::User);
let mut messages: Vec<&Envelope> = ctx.context_for_llm();
messages.push(&input);
let prompt = ChatTemplateFormatter::format(&messages, ChatTemplateFormat::ChatML);
assert_eq!(
prompt,
"<|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\n"
);
}
#[test]
fn test_execute_with_context_llama_format() {
use super::super::chat_template::{ChatTemplateFormat, ChatTemplateFormatter};
use crate::conversation::ConversationContext;
let mut ctx = ConversationContext::new().with_system(
Envelope::new(EnvelopeKind::Text("Be concise.".to_string()))
.with_role(MessageRole::System),
);
ctx.push(Envelope::new(EnvelopeKind::Text("Hi!".to_string())).with_role(MessageRole::User));
ctx.push(
Envelope::new(EnvelopeKind::Text("Hello!".to_string()))
.with_role(MessageRole::Assistant),
);
let input =
Envelope::new(EnvelopeKind::Text("Bye!".to_string())).with_role(MessageRole::User);
let mut messages: Vec<&Envelope> = ctx.context_for_llm();
messages.push(&input);
let prompt = ChatTemplateFormatter::format(&messages, ChatTemplateFormat::Llama);
assert!(prompt.contains("<<SYS>>"));
assert!(prompt.contains("Be concise."));
assert!(prompt.contains("[INST]"));
assert!(prompt.contains("[/INST]"));
}
#[test]
fn test_chat_template_format_from_str() {
use super::super::chat_template::ChatTemplateFormat;
assert_eq!(
ChatTemplateFormat::from_str("chatml"),
Some(ChatTemplateFormat::ChatML)
);
assert_eq!(
ChatTemplateFormat::from_str("llama"),
Some(ChatTemplateFormat::Llama)
);
assert_eq!(
ChatTemplateFormat::from_str("llama2"),
Some(ChatTemplateFormat::Llama)
);
assert_eq!(ChatTemplateFormat::from_str("unknown"), None);
let default: ChatTemplateFormat = Default::default();
assert_eq!(default, ChatTemplateFormat::ChatML);
}
#[test]
fn test_crossfade_empty_chunks() {
let chunks: Vec<Vec<f32>> = vec![];
let result = crossfade_audio_chunks(&chunks, 480);
assert!(result.is_empty());
}
#[test]
fn test_crossfade_single_chunk_unchanged() {
let chunk = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let result = crossfade_audio_chunks(std::slice::from_ref(&chunk), 480);
assert_eq!(result, chunk);
}
#[test]
fn test_crossfade_two_chunks() {
let crossfade_len = 4;
let chunk_a = vec![1.0; 10];
let chunk_b = vec![0.0; 10];
let result = crossfade_audio_chunks(&[chunk_a, chunk_b], crossfade_len);
assert_eq!(result.len(), 16);
for &v in &result[..6] {
assert!((v - 1.0).abs() < 1e-6);
}
for i in 0..crossfade_len {
let t = (i + 1) as f32 / (crossfade_len + 1) as f32;
let expected = 1.0 - t;
assert!(
(result[6 + i] - expected).abs() < 1e-6,
"at overlap index {i}: got {}, expected {expected}",
result[6 + i]
);
}
for &v in &result[10..] {
assert!((v - 0.0).abs() < 1e-6);
}
}
#[test]
fn test_crossfade_three_chunks() {
let crossfade_len = 2;
let chunk_a = vec![1.0; 8];
let chunk_b = vec![0.5; 8];
let chunk_c = vec![0.0; 8];
let result = crossfade_audio_chunks(&[chunk_a, chunk_b, chunk_c], crossfade_len);
assert_eq!(result.len(), 20);
}
#[test]
fn test_crossfade_short_chunk_skips_crossfade() {
let crossfade_len = 4;
let chunk_a = vec![1.0; 10];
let chunk_b = vec![0.5; 6];
let result = crossfade_audio_chunks(&[chunk_a, chunk_b], crossfade_len);
assert_eq!(result.len(), 16);
assert!((result[9] - 1.0).abs() < 1e-6);
assert!((result[10] - 0.5).abs() < 1e-6);
}
#[test]
fn test_crossfade_preserves_total_energy() {
let crossfade_len = 4;
let chunk_a = vec![0.5; 10];
let chunk_b = vec![0.5; 10];
let result = crossfade_audio_chunks(&[chunk_a, chunk_b], crossfade_len);
for &v in &result {
assert!(
(v - 0.5).abs() < 1e-6,
"expected 0.5, got {v} — crossfade should preserve constant signal"
);
}
}
/// Tests for the span cost-attribution stamps (`backend` / `quantization`)
/// written by `stamp_llm_span_cost_attribution` and `stamp_llm_runtime_backend`.
///
/// The tracing collector is driven through free functions
/// (`init_tracing` / `reset_tracing` / `get_stages_json`), i.e. shared global
/// state, while the test harness runs tests concurrently by default; every
/// capture is therefore serialised through `GLOBAL_TRACE_LOCK`.
#[cfg(any(feature = "llm-mistral", feature = "llm-llamacpp"))]
mod stamp_llm_span {
    use super::*;
    use crate::tracing;
    use std::sync::{Mutex, MutexGuard, PoisonError};

    /// Serialises access to the global tracing state across this module's tests.
    static GLOBAL_TRACE_LOCK: Mutex<()> = Mutex::new(());

    /// Name of the span opened around every stamp call in this module.
    const SPAN_NAME: &str = "execute:test";

    /// Acquires the trace lock, recovering it if a previous test panicked
    /// while holding it.
    ///
    /// A plain `.lock().unwrap()` would turn one failing assertion into a
    /// cascade: the mutex becomes poisoned and every later test in the module
    /// fails with an unrelated `PoisonError` instead of reporting its own result.
    fn lock_trace() -> MutexGuard<'static, ()> {
        GLOBAL_TRACE_LOCK
            .lock()
            .unwrap_or_else(PoisonError::into_inner)
    }

    /// Builds GGUF model metadata for `test.gguf`, optionally annotated with a
    /// bundle-level `"backend"` hint.
    fn gguf_metadata(backend_hint: Option<&str>) -> ModelMetadata {
        let mut bundle_metadata = HashMap::new();
        if let Some(hint) = backend_hint {
            bundle_metadata.insert("backend".to_string(), serde_json::json!(hint));
        }
        ModelMetadata {
            model_id: "test-gguf".into(),
            version: "1".into(),
            execution_template: ExecutionTemplate::Gguf {
                model_file: "test.gguf".into(),
                chat_template: None,
                context_length: 2048,
                generation_params: None,
            },
            preprocessing: Vec::new(),
            postprocessing: Vec::new(),
            files: Vec::new(),
            description: None,
            metadata: bundle_metadata,
            voices: None,
            max_chunk_chars: None,
            trim_trailing_samples: None,
        }
    }

    /// Runs `stamp` inside a fresh span named `span_name` and returns that
    /// span's string-valued metadata from `tracing::get_stages_json()`.
    ///
    /// Holds the trace lock for the whole capture and resets tracing before and
    /// after, so captures cannot observe each other's spans. Metadata values that
    /// are not JSON strings are skipped.
    ///
    /// # Panics
    /// Panics if no span named `span_name` appears in the stages JSON.
    fn record_span(span_name: &str, stamp: impl FnOnce()) -> HashMap<String, String> {
        let _lock = lock_trace();
        tracing::init_tracing(true);
        tracing::reset_tracing();
        {
            // Drop the guard before snapshotting; presumably this is what
            // closes the span so it shows up in the stages JSON.
            let _guard = tracing::SpanGuard::new(span_name);
            stamp();
        }
        let json = tracing::get_stages_json();
        tracing::reset_tracing();
        let span = json["spans"]
            .as_array()
            .and_then(|spans| spans.iter().find(|s| s["name"].as_str() == Some(span_name)))
            .expect("span recorded by SpanGuard must be present in stages json");
        span["metadata"]
            .as_object()
            .map(|obj| {
                obj.iter()
                    .filter_map(|(k, v)| v.as_str().map(|s| (k.clone(), s.to_owned())))
                    .collect()
            })
            .unwrap_or_default()
    }

    /// Captures the metadata stamped from the model bundle alone
    /// (template + bundle hints), without any runtime overwrite.
    fn capture_span_metadata(
        span_name: &str,
        metadata: &ModelMetadata,
    ) -> HashMap<String, String> {
        record_span(span_name, || stamp_llm_span_cost_attribution(metadata))
    }

    #[test]
    fn unannotated_gguf_stamps_llamacpp_default() {
        let captured = capture_span_metadata(SPAN_NAME, &gguf_metadata(None));
        assert_eq!(
            captured.get("backend").map(String::as_str),
            Some("llamacpp"),
            "chat-context flow must default unannotated GGUF bundles to llamacpp so PlatformEvent.backend is non-empty"
        );
    }

    #[test]
    fn mistralrs_hint_wins_on_gguf() {
        let captured = capture_span_metadata(SPAN_NAME, &gguf_metadata(Some("mistralrs")));
        assert_eq!(
            captured.get("backend").map(String::as_str),
            Some("mistralrs")
        );
    }

    #[test]
    fn legacy_mistral_alias_normalises_to_mistralrs() {
        let captured = capture_span_metadata(SPAN_NAME, &gguf_metadata(Some("mistral")));
        assert_eq!(
            captured.get("backend").map(String::as_str),
            Some("mistralrs"),
            "the legacy `mistral` bundle alias must canonicalise to the wire label"
        );
    }

    #[test]
    fn quantization_stamped_from_gguf_filename() {
        let mut metadata = gguf_metadata(None);
        metadata.execution_template = ExecutionTemplate::Gguf {
            model_file: "tinyllama-1.1b-chat-q4_k_m.gguf".into(),
            chat_template: None,
            context_length: 2048,
            generation_params: None,
        };
        let captured = capture_span_metadata(SPAN_NAME, &metadata);
        assert_eq!(
            captured.get("quantization").map(String::as_str),
            Some("q4_k_m"),
            "stamp must surface the filename-inferred quantization alongside backend"
        );
    }

    /// Minimal `LlmBackend` whose only meaningful behaviour is reporting the
    /// configured wire label; inference entry points are never expected to run.
    struct WireLabelStub(Option<&'static str>);

    impl crate::runtime_adapter::llm::LlmBackend for WireLabelStub {
        fn name(&self) -> &str {
            "wire-label-stub"
        }

        fn wire_label(&self) -> Option<&'static str> {
            self.0
        }

        fn supported_formats(&self) -> Vec<&'static str> {
            vec!["gguf"]
        }

        fn load(
            &mut self,
            _config: &crate::runtime_adapter::llm::LlmConfig,
        ) -> crate::runtime_adapter::llm::LlmResult<()> {
            Ok(())
        }

        fn is_loaded(&self) -> bool {
            true
        }

        fn unload(&mut self) -> crate::runtime_adapter::llm::LlmResult<()> {
            Ok(())
        }

        fn generate(
            &self,
            _messages: &[crate::runtime_adapter::llm::ChatMessage],
            _config: &crate::runtime_adapter::llm::GenerationConfig,
        ) -> crate::runtime_adapter::llm::LlmResult<crate::runtime_adapter::llm::GenerationOutput>
        {
            unreachable!("stub backend should not be invoked for inference in this test")
        }

        fn generate_raw(
            &self,
            _prompt: &str,
            _config: &crate::runtime_adapter::llm::GenerationConfig,
        ) -> crate::runtime_adapter::llm::LlmResult<crate::runtime_adapter::llm::GenerationOutput>
        {
            unreachable!("stub backend should not be invoked for inference in this test")
        }
    }

    /// Captures the metadata produced by the bundle stamp followed by the
    /// runtime stamp from an adapter wrapping `WireLabelStub(wire_label)`,
    /// mirroring the order in which the two stamps are applied in this test.
    fn capture_with_runtime_overwrite(
        metadata: &ModelMetadata,
        wire_label: Option<&'static str>,
    ) -> HashMap<String, String> {
        let adapter = crate::runtime_adapter::llm::LlmRuntimeAdapter::with_backend(Box::new(
            WireLabelStub(wire_label),
        ));
        record_span(SPAN_NAME, || {
            stamp_llm_span_cost_attribution(metadata);
            stamp_llm_runtime_backend(&adapter);
        })
    }

    #[test]
    fn runtime_wire_label_overwrites_template_default() {
        let captured = capture_with_runtime_overwrite(&gguf_metadata(None), Some("mistralrs"));
        assert_eq!(
            captured.get("backend").map(String::as_str),
            Some("mistralrs"),
            "runtime wire label must overwrite the template-derived default"
        );
    }

    #[test]
    fn runtime_overwrite_preserves_template_default_when_label_absent() {
        let captured = capture_with_runtime_overwrite(&gguf_metadata(None), None);
        assert_eq!(
            captured.get("backend").map(String::as_str),
            Some("llamacpp"),
            "stub backends without a wire label must not erase the template-derived stamp"
        );
    }
}
}