mod standard;
mod tts;
#[cfg(any(feature = "llm-mistral", feature = "llm-llamacpp"))]
mod llm;
mod codec_tts;
pub use standard::StandardStrategy;
pub use tts::TtsStrategy;
#[cfg(any(feature = "llm-mistral", feature = "llm-llamacpp"))]
pub use llm::LlmStrategy;
#[cfg(any(feature = "llm-mistral", feature = "llm-llamacpp"))]
#[allow(unused_imports)]
pub use llm::{LlmGenerationParams, LlmInference, LlmModelConfig};
#[allow(unused_imports)]
pub use codec_tts::CodecTtsStrategy;
#[cfg(not(any(feature = "llm-mistral", feature = "llm-llamacpp")))]
mod llm;
#[cfg(not(any(feature = "llm-mistral", feature = "llm-llamacpp")))]
#[allow(unused_imports)]
pub use llm::{LlmGenerationParams, LlmInference, LlmModelConfig, LlmStrategy};
use super::template::ModelMetadata;
use super::types::ExecutorResult;
use crate::ir::Envelope;
use crate::runtime_adapter::ModelRuntime;
use std::collections::HashMap;
use std::path::Path;
/// Borrowed state handed to an [`ExecutionStrategy`] while it executes.
///
/// Nothing is owned here: both the base path and the runtime map are borrowed
/// from the caller for the lifetime `'a`.
pub struct ExecutionContext<'a> {
    /// Base location that file names passed to `resolve_path` are joined onto.
    pub base_path: &'a str,
    /// Runtimes keyed by name, borrowed mutably from the caller.
    pub runtimes: &'a mut HashMap<String, Box<dyn ModelRuntime>>,
}

impl ExecutionContext<'_> {
    /// Joins `file` onto `base_path` and returns the resulting path.
    ///
    /// Follows `PathBuf::push` semantics, so an absolute `file` replaces the
    /// base instead of being appended to it.
    pub fn resolve_path(&self, file: &str) -> std::path::PathBuf {
        let mut resolved = Path::new(self.base_path).to_path_buf();
        resolved.push(file);
        resolved
    }

    /// Returns a mutable handle to the runtime stored under `name`, or `None`
    /// when the map has no entry for that key.
    pub fn get_runtime(&mut self, name: &str) -> Option<&mut Box<dyn ModelRuntime>> {
        let registry = &mut *self.runtimes;
        registry.get_mut(name)
    }
}
/// Pluggable way of executing a model described by [`ModelMetadata`].
///
/// [`StrategyResolver`] stores boxed implementations and returns the first
/// one whose `can_handle` reports `true`. Implementors must be `Send + Sync`.
pub trait ExecutionStrategy: Send + Sync {
/// Reports whether this strategy accepts the model described by `metadata`.
fn can_handle(&self, metadata: &ModelMetadata) -> bool;
/// Executes this strategy for `metadata`, reading `input` by reference.
///
/// `ctx` is borrowed mutably and exposes the base path and the runtime map.
/// The output [`Envelope`] is returned wrapped in [`ExecutorResult`].
fn execute(
&self,
ctx: &mut ExecutionContext<'_>,
metadata: &ModelMetadata,
input: &Envelope,
) -> ExecutorResult<Envelope>;
/// Static identifier of the strategy; this module's tests expect values such
/// as `"tts"`, `"standard"`, `"llm"` and `"codec_tts"`.
fn name(&self) -> &'static str;
}
/// Ordered registry of execution strategies, consulted first-match-wins.
pub struct StrategyResolver {
    /// Candidates in priority order; an earlier entry shadows every later one
    /// that would also accept the same metadata.
    strategies: Vec<Box<dyn ExecutionStrategy>>,
}

impl StrategyResolver {
    /// Builds a resolver holding the built-in strategies.
    ///
    /// Priority order: codec TTS, then LLM (both only when the `llm-mistral`
    /// or `llm-llamacpp` feature is enabled), then TTS, then standard.
    // The entries are pushed one at a time so the feature-gated ones can be
    // compiled out as a block; that shape is what trips the lint below.
    #[allow(clippy::vec_init_then_push)]
    pub fn new() -> Self {
        let mut registry: Vec<Box<dyn ExecutionStrategy>> = Vec::new();

        #[cfg(any(feature = "llm-mistral", feature = "llm-llamacpp"))]
        {
            registry.push(Box::new(CodecTtsStrategy::new()));
            registry.push(Box::new(LlmStrategy::new()));
        }

        registry.push(Box::new(TtsStrategy::new()));
        registry.push(Box::new(StandardStrategy::new()));

        StrategyResolver {
            strategies: registry,
        }
    }

    /// Returns the first registered strategy whose `can_handle` accepts
    /// `metadata`, or `None` when no strategy does.
    pub fn resolve(&self, metadata: &ModelMetadata) -> Option<&dyn ExecutionStrategy> {
        for candidate in &self.strategies {
            if candidate.can_handle(metadata) {
                let chosen: &dyn ExecutionStrategy = candidate.as_ref();
                return Some(chosen);
            }
        }
        None
    }
}

impl Default for StrategyResolver {
    /// Equivalent to [`StrategyResolver::new`].
    fn default() -> Self {
        StrategyResolver::new()
    }
}
#[cfg(test)]
mod tests {
    //! Checks which strategy [`StrategyResolver`] selects for representative
    //! model metadata, including the priority between overlapping strategies.

    use super::*;
    #[cfg(any(feature = "llm-mistral", feature = "llm-llamacpp"))]
    use crate::execution::template::ExecutionTemplate;
    use crate::execution::template::PreprocessingStep;

    /// Name of the strategy a freshly built resolver selects for `metadata`,
    /// or `None` when no strategy accepts it.
    ///
    /// Comparing the returned `Option` directly replaces the
    /// `is_some()` + `unwrap()` pair and reports `None` in the failure output.
    fn resolved_name(metadata: &ModelMetadata) -> Option<&'static str> {
        let resolver = StrategyResolver::new();
        resolver.resolve(metadata).map(|strategy| strategy.name())
    }

    /// Phonemize preprocessing step shared by the ONNX TTS fixtures.
    fn phonemize_step() -> PreprocessingStep {
        PreprocessingStep::Phonemize {
            tokens_file: "tokens.txt".to_string(),
            backend: Default::default(),
            dict_file: None,
            language: None,
            add_padding: true,
            normalize_text: false,
            silence_tokens: None,
        }
    }

    /// GGUF metadata with no pre- or postprocessing steps.
    #[cfg(any(feature = "llm-mistral", feature = "llm-llamacpp"))]
    fn plain_gguf_metadata(model_id: &str) -> ModelMetadata {
        ModelMetadata {
            model_id: model_id.to_string(),
            version: "1.0".to_string(),
            execution_template: ExecutionTemplate::Gguf {
                model_file: "model.gguf".to_string(),
                chat_template: None,
                context_length: 4096,
                generation_params: None,
            },
            preprocessing: vec![],
            postprocessing: vec![],
            files: vec!["model.gguf".to_string()],
            description: None,
            metadata: std::collections::HashMap::new(),
            voices: None,
            max_chunk_chars: None,
            trim_trailing_samples: None,
        }
    }

    /// An ONNX model with a phonemize step resolves to the TTS strategy.
    #[test]
    fn test_resolver_selects_tts_for_phonemize() {
        let metadata = ModelMetadata::onnx("test-tts", "1.0", "model.onnx")
            .with_preprocessing(phonemize_step());
        assert_eq!(resolved_name(&metadata), Some("tts"));
    }

    /// A bare ONNX model resolves to the standard strategy.
    #[test]
    fn test_resolver_selects_standard_for_onnx() {
        let metadata = ModelMetadata::onnx("test-model", "1.0", "model.onnx");
        assert_eq!(resolved_name(&metadata), Some("standard"));
    }

    /// A GGUF model resolves to the LLM strategy.
    #[cfg(any(feature = "llm-mistral", feature = "llm-llamacpp"))]
    #[test]
    fn test_resolver_selects_llm_for_gguf() {
        let metadata = plain_gguf_metadata("test-llm");
        assert_eq!(resolved_name(&metadata), Some("llm"));
    }

    /// A GGUF model with a codec-decode postprocessing step resolves to the
    /// codec TTS strategy rather than the plain LLM strategy.
    #[cfg(any(feature = "llm-mistral", feature = "llm-llamacpp"))]
    #[test]
    fn test_resolver_selects_codec_tts_for_gguf_with_codec_decode() {
        use crate::execution::template::PostprocessingStep;

        // Kept fully explicit: this is the only fixture with this shape.
        let metadata = ModelMetadata {
            model_id: "neutts-nano-q4".to_string(),
            version: "1.0".to_string(),
            execution_template: ExecutionTemplate::Gguf {
                model_file: "model.gguf".to_string(),
                chat_template: None,
                context_length: 2048,
                generation_params: None,
            },
            preprocessing: vec![PreprocessingStep::PhonemeRaw {
                backend: Default::default(),
                language: Some("en-us".to_string()),
            }],
            postprocessing: vec![PostprocessingStep::CodecDecode {
                decoder_model: "neucodec-decoder-int8.onnx".to_string(),
                sample_rate: 24000,
                token_pattern: r"<\|speech_(\d+)\|>".to_string(),
                apply_postprocessing: true,
            }],
            files: vec!["model.gguf".to_string()],
            description: None,
            metadata: std::collections::HashMap::new(),
            voices: None,
            max_chunk_chars: None,
            trim_trailing_samples: None,
        };
        assert_eq!(resolved_name(&metadata), Some("codec_tts"));
    }

    /// A GGUF model without codec postprocessing is not claimed by codec TTS,
    /// even though that strategy is registered ahead of the LLM strategy.
    #[cfg(any(feature = "llm-mistral", feature = "llm-llamacpp"))]
    #[test]
    fn test_resolver_selects_llm_not_codec_tts_for_plain_gguf() {
        let metadata = plain_gguf_metadata("plain-llm");
        assert_eq!(resolved_name(&metadata), Some("llm"));
    }

    /// With the LLM features enabled, an ONNX phonemize model still resolves
    /// to the TTS strategy and not to codec TTS.
    #[cfg(any(feature = "llm-mistral", feature = "llm-llamacpp"))]
    #[test]
    fn test_resolver_selects_tts_not_codec_for_onnx_phonemize() {
        let metadata = ModelMetadata::onnx("kokoro-82m", "1.0", "model.onnx")
            .with_preprocessing(phonemize_step());
        assert_eq!(resolved_name(&metadata), Some("tts"));
    }
}