use log::{debug, info};
use super::{ExecutionContext, ExecutionStrategy};
use crate::execution::modes::execute_bert_inference;
use crate::execution::session_factory::OnnxSessionFactory;
use crate::execution::template::{ExecutionTemplate, ModelMetadata, PreprocessingStep};
use crate::execution::types::{ExecutorResult, PreprocessedData, RawOutputs};
use crate::execution::{postprocessing, preprocessing};
use crate::ir::Envelope;
use crate::runtime_adapter::onnx::{ExecutionProviderKind, SessionOptions};
use crate::runtime_adapter::AdapterError;
use crate::tracing as xybrid_trace;
pub struct StandardStrategy;
impl StandardStrategy {
pub fn new() -> Self {
Self
}
fn is_tts_model(metadata: &ModelMetadata) -> bool {
metadata
.preprocessing
.iter()
.any(|step| matches!(step, PreprocessingStep::Phonemize { .. }))
}
fn is_llm_model(metadata: &ModelMetadata) -> bool {
matches!(metadata.execution_template, ExecutionTemplate::Gguf { .. })
}
fn is_model_graph(metadata: &ModelMetadata) -> bool {
matches!(
metadata.execution_template,
ExecutionTemplate::ModelGraph { .. }
)
}
fn get_runtime_info(metadata: &ModelMetadata) -> ExecutorResult<(&'static str, &str)> {
match &metadata.execution_template {
ExecutionTemplate::Onnx { model_file } => Ok(("onnx", model_file)),
ExecutionTemplate::SafeTensors { model_file, .. } => Ok(("candle", model_file)),
ExecutionTemplate::CoreMl { model_file } => Ok(("coreml", model_file)),
ExecutionTemplate::TfLite { model_file } => Ok(("tflite", model_file)),
ExecutionTemplate::ModelGraph { .. } => Err(AdapterError::InvalidInput(
"ModelGraph should use PipelineStrategy".to_string(),
)),
ExecutionTemplate::Gguf { .. } => Err(AdapterError::InvalidInput(
"GGUF models should use LlmStrategy".to_string(),
)),
}
}
fn run_preprocessing(
&self,
ctx: &ExecutionContext<'_>,
metadata: &ModelMetadata,
input: &Envelope,
) -> ExecutorResult<PreprocessedData> {
if metadata.preprocessing.is_empty() {
debug!(target: "xybrid_core", "No preprocessing steps configured");
return PreprocessedData::from_envelope(input);
}
info!(
target: "xybrid_core",
"Running {} preprocessing step(s)",
metadata.preprocessing.len()
);
let _preprocess_span = xybrid_trace::SpanGuard::new("preprocessing");
xybrid_trace::add_metadata("steps", metadata.preprocessing.len().to_string());
let mut data = PreprocessedData::from_envelope(input)?;
for step in &metadata.preprocessing {
let step_name = step.step_name();
debug!(target: "xybrid_core", "Applying preprocessing: {}", step_name);
let _step_span = xybrid_trace::SpanGuard::new(format!("preprocessing:{}", step_name));
data = preprocessing::apply_preprocessing_step(step, data, input, ctx.base_path)?;
}
debug!(target: "xybrid_core", "Preprocessing complete");
Ok(data)
}
fn run_postprocessing(
&self,
ctx: &ExecutionContext<'_>,
metadata: &ModelMetadata,
outputs: RawOutputs,
) -> ExecutorResult<Envelope> {
if metadata.postprocessing.is_empty() {
debug!(target: "xybrid_core", "No postprocessing steps configured");
return outputs.to_envelope();
}
info!(
target: "xybrid_core",
"Running {} postprocessing step(s)",
metadata.postprocessing.len()
);
let _postprocess_span = xybrid_trace::SpanGuard::new("postprocessing");
xybrid_trace::add_metadata("steps", metadata.postprocessing.len().to_string());
let mut data = outputs;
for step in &metadata.postprocessing {
let step_name = step.step_name();
debug!(target: "xybrid_core", "Applying postprocessing: {}", step_name);
let _step_span = xybrid_trace::SpanGuard::new(format!("postprocessing:{}", step_name));
data = postprocessing::apply_postprocessing_step(step, data, ctx.base_path)?;
}
debug!(target: "xybrid_core", "Postprocessing complete");
data.to_envelope()
}
fn execute_bert(
&self,
_ctx: &ExecutionContext<'_>,
_metadata: &ModelMetadata,
preprocessed: &PreprocessedData,
model_path: &std::path::Path,
) -> ExecutorResult<Envelope> {
debug!(target: "xybrid_core", "Detected BERT-style inference (token IDs)");
let (ids, attention_mask, token_type_ids) = preprocessed
.as_token_ids()
.ok_or_else(|| AdapterError::InvalidInput("Expected token IDs".to_string()))?;
let session = OnnxSessionFactory::create_session(
model_path,
ExecutionProviderKind::Cpu,
SessionOptions::default(),
)?;
let raw_outputs = execute_bert_inference(&session, ids, attention_mask, token_type_ids)?;
crate::runtime_adapter::tensor_utils::tensors_to_envelope(
&raw_outputs,
session.output_names(),
)
}
fn execute_runtime(
&self,
ctx: &mut ExecutionContext<'_>,
_metadata: &ModelMetadata,
preprocessed: PreprocessedData,
runtime_type: &str,
model_path: &std::path::Path,
) -> ExecutorResult<Envelope> {
debug!(target: "xybrid_core", "Using standard execution path");
let runtime_input = preprocessed.to_envelope()?;
let runtime = ctx.get_runtime(runtime_type).ok_or_else(|| {
AdapterError::RuntimeError(format!("Runtime '{}' not configured", runtime_type))
})?;
debug!(target: "xybrid_core", "Loading model: {:?}", model_path);
runtime
.load(model_path)
.map_err(|e| AdapterError::RuntimeError(format!("Load failed: {}", e)))?;
debug!(target: "xybrid_core", "Running inference");
runtime.execute(&runtime_input)
}
}
impl Default for StandardStrategy {
fn default() -> Self {
Self::new()
}
}
impl ExecutionStrategy for StandardStrategy {
fn can_handle(&self, metadata: &ModelMetadata) -> bool {
!Self::is_tts_model(metadata)
&& !Self::is_llm_model(metadata)
&& !Self::is_model_graph(metadata)
}
fn execute(
&self,
ctx: &mut ExecutionContext<'_>,
metadata: &ModelMetadata,
input: &Envelope,
) -> ExecutorResult<Envelope> {
let _span = xybrid_trace::SpanGuard::new("standard_execution");
let (runtime_type, model_file) = Self::get_runtime_info(metadata)?;
let model_path = ctx.resolve_path(model_file);
debug!(
target: "xybrid_core",
"Using {} runtime with model: {}",
runtime_type,
model_file
);
let preprocessed = self.run_preprocessing(ctx, metadata, input)?;
let result_envelope = if preprocessed.is_token_ids() {
self.execute_bert(ctx, metadata, &preprocessed, &model_path)?
} else {
self.execute_runtime(ctx, metadata, preprocessed, runtime_type, &model_path)?
};
let raw_outputs = RawOutputs::from_envelope(&result_envelope)?;
let result = self.run_postprocessing(ctx, metadata, raw_outputs)?;
info!(
target: "xybrid_core",
"Model execution complete: {} -> {}",
metadata.model_id,
result.kind_str()
);
Ok(result)
}
fn name(&self) -> &'static str {
"standard"
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_can_handle_onnx() {
let strategy = StandardStrategy::new();
let metadata = ModelMetadata::onnx("test-model", "1.0", "model.onnx");
assert!(strategy.can_handle(&metadata));
}
#[test]
fn test_cannot_handle_tts() {
let strategy = StandardStrategy::new();
let metadata = ModelMetadata::onnx("test-tts", "1.0", "model.onnx").with_preprocessing(
PreprocessingStep::Phonemize {
tokens_file: "tokens.txt".to_string(),
backend: Default::default(),
dict_file: None,
language: None,
add_padding: true,
normalize_text: false,
silence_tokens: None,
},
);
assert!(!strategy.can_handle(&metadata));
}
#[test]
fn test_get_runtime_info_onnx() {
let metadata = ModelMetadata::onnx("test", "1.0", "model.onnx");
let (runtime, file) = StandardStrategy::get_runtime_info(&metadata).unwrap();
assert_eq!(runtime, "onnx");
assert_eq!(file, "model.onnx");
}
#[test]
fn test_get_runtime_info_safetensors() {
let metadata = ModelMetadata::safetensors("test", "1.0", "model.safetensors", "whisper");
let (runtime, file) = StandardStrategy::get_runtime_info(&metadata).unwrap();
assert_eq!(runtime, "candle");
assert_eq!(file, "model.safetensors");
}
}