use super::steps::{PostprocessingStep, PreprocessingStep};
use super::voice::{VoiceConfig, VoiceInfo};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
#[cfg(feature = "schema")]
use schemars::JsonSchema;
/// Variant-specific description of how a model is packaged, stored in
/// [`ModelMetadata::execution_template`].
///
/// Serialized with serde's internally tagged representation: the variant name is
/// written under the `type` key next to the variant's own fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "schema", derive(JsonSchema))]
#[serde(tag = "type")]
pub enum ExecutionTemplate {
/// ONNX model identified by a single `model_file`.
Onnx {
model_file: String,
},
/// SafeTensors model: a `model_file` plus optional companion settings.
/// Each optional field deserializes to `None` when it is absent.
SafeTensors {
model_file: String,
#[serde(default)]
architecture: Option<String>,
#[serde(default)]
config_file: Option<String>,
#[serde(default)]
tokenizer_file: Option<String>,
},
/// Core ML model identified by a single `model_file`.
CoreMl {
model_file: String,
},
/// TensorFlow Lite model identified by a single `model_file`.
TfLite {
model_file: String,
},
/// Multi-stage model described by a list of [`PipelineStage`] entries.
ModelGraph {
stages: Vec<PipelineStage>,
/// Free-form JSON settings; an empty map when the key is absent.
#[serde(default)]
config: HashMap<String, serde_json::Value>,
},
/// GGUF model file with optional chat/generation settings.
Gguf {
model_file: String,
/// Optional chat template string; `None` when absent.
#[serde(default)]
chat_template: Option<String>,
/// Filled in by `default_context_length` when the key is absent.
#[serde(default = "default_context_length")]
context_length: usize,
/// Optional generation settings; `None` when absent.
#[serde(default)]
generation_params: Option<GenerationParams>,
},
}
/// Optional text-generation settings, referenced by [`ExecutionTemplate::Gguf`] as
/// `generation_params`.
///
/// Every field may be left out of the input: `Option` fields then deserialize to
/// `None` and `stop_sequences` to an empty list. Fields holding `None` (or an empty
/// `stop_sequences`) are skipped when serializing, so a default value round-trips
/// as an empty object.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
#[cfg_attr(feature = "schema", derive(JsonSchema))]
pub struct GenerationParams {
/// Token limit for generation, if set.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<usize>,
/// Sampling temperature, if set (interpretation is left to the consuming runtime).
#[serde(default, skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
/// `top_p` sampling value, if set.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
/// `top_k` sampling value, if set.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub top_k: Option<usize>,
/// Repetition penalty, if set.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub repetition_penalty: Option<f32>,
/// Stop strings; omitted from the serialized form when empty.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub stop_sequences: Vec<String>,
}
/// One entry of the `stages` list in [`ExecutionTemplate::ModelGraph`].
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "schema", derive(JsonSchema))]
pub struct PipelineStage {
/// Name of the stage.
pub name: String,
/// Model file used by this stage.
pub model_file: String,
/// How the stage is run; falls back to `ExecutionMode::default()`
/// (the `SingleShot` variant) when the key is absent.
#[serde(default)]
pub execution_mode: ExecutionMode,
/// Names of the stage's inputs (required key).
pub inputs: Vec<String>,
/// Names of the stage's outputs (required key).
pub outputs: Vec<String>,
/// Free-form JSON settings for the stage; an empty map when the key is absent.
#[serde(default)]
pub config: HashMap<String, serde_json::Value>,
}
/// How a [`PipelineStage`] is run.
///
/// Serialized with serde's internally tagged representation (the variant name is
/// stored under the `type` key). `SingleShot` is the `Default` variant.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[cfg_attr(feature = "schema", derive(JsonSchema))]
#[serde(tag = "type")]
pub enum ExecutionMode {
    /// Mode without any extra settings; used when no mode is specified.
    #[default]
    SingleShot,
    /// Token-by-token decoding bounded by `max_tokens`, with explicit start/end token ids.
    Autoregressive {
        max_tokens: usize,
        start_token_id: i64,
        end_token_id: i64,
        // Use the shared helper rather than a bare `#[serde(default)]`: the bare form
        // fills in `f32::default()` (0.0) when the key is omitted, which disagrees with
        // the `WhisperDecoder` variant below and is not a usable penalty value.
        #[serde(default = "default_repetition_penalty")]
        repetition_penalty: f32,
    },
    /// Decoding mode carrying the additional Whisper-specific token ids.
    WhisperDecoder {
        max_tokens: usize,
        start_token_id: i64,
        end_token_id: i64,
        language_token_id: i64,
        task_token_id: i64,
        no_timestamps_token_id: i64,
        /// Token ids to suppress; empty when the key is absent.
        #[serde(default)]
        suppress_tokens: Vec<i64>,
        /// Filled in by `default_repetition_penalty` when the key is absent.
        #[serde(default = "default_repetition_penalty")]
        repetition_penalty: f32,
    },
    /// Multi-step mode running `num_steps` steps on a [`RefinementSchedule`].
    IterativeRefinement {
        num_steps: usize,
        /// Falls back to `RefinementSchedule::default()` when the key is absent.
        #[serde(default)]
        schedule: RefinementSchedule,
    },
}
/// Step schedule used by [`ExecutionMode::IterativeRefinement`].
///
/// Serialized with serde's internally tagged representation (the variant name is
/// stored under the `type` key). `Linear` is the `Default` variant.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[cfg_attr(feature = "schema", derive(JsonSchema))]
#[serde(tag = "type")]
pub enum RefinementSchedule {
    /// Default schedule.
    #[default]
    Linear,
    /// Cosine schedule.
    Cosine,
    /// Caller-supplied list of timesteps.
    Custom {
        timesteps: Vec<f32>,
    },
}
/// Top-level metadata record for a model: identity, execution template,
/// pre/post-processing steps, file list and optional extras.
///
/// `model_id`, `version`, `execution_template` and `files` are required when
/// deserializing; every other field has a serde default.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "schema", derive(JsonSchema))]
pub struct ModelMetadata {
/// Identifier of the model.
pub model_id: String,
/// Version string of the model.
pub version: String,
/// Packaging/execution description; see [`ExecutionTemplate`].
pub execution_template: ExecutionTemplate,
/// Preprocessing steps in list order; empty when the key is absent.
#[serde(default)]
pub preprocessing: Vec<PreprocessingStep>,
/// Postprocessing steps in list order; empty when the key is absent.
#[serde(default)]
pub postprocessing: Vec<PostprocessingStep>,
/// File names belonging to the model.
pub files: Vec<String>,
/// Optional human-readable description.
#[serde(default)]
pub description: Option<String>,
/// Free-form JSON key/value pairs. `quantization_label_from_metadata` reads the
/// `"quantization"` key from this map.
#[serde(default)]
pub metadata: HashMap<String, serde_json::Value>,
/// Optional voice configuration; not serialized when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub voices: Option<VoiceConfig>,
/// Optional character limit per chunk; not serialized when `None`.
/// NOTE(review): presumably used to split long input text — confirm with the consumer.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub max_chunk_chars: Option<usize>,
/// Optional count of trailing samples to trim; not serialized when `None`.
/// NOTE(review): presumably applies to generated audio — confirm with the consumer.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub trim_trailing_samples: Option<usize>,
}
impl ModelMetadata {
pub fn onnx(
model_id: impl Into<String>,
version: impl Into<String>,
model_file: impl Into<String>,
) -> Self {
let model_file = model_file.into();
Self {
model_id: model_id.into(),
version: version.into(),
execution_template: ExecutionTemplate::Onnx {
model_file: model_file.clone(),
},
preprocessing: Vec::new(),
postprocessing: Vec::new(),
files: vec![model_file],
description: None,
metadata: HashMap::new(),
voices: None,
max_chunk_chars: None,
trim_trailing_samples: None,
}
}
pub fn safetensors(
model_id: impl Into<String>,
version: impl Into<String>,
model_file: impl Into<String>,
architecture: impl Into<String>,
) -> Self {
let model_file = model_file.into();
Self {
model_id: model_id.into(),
version: version.into(),
execution_template: ExecutionTemplate::SafeTensors {
model_file: model_file.clone(),
architecture: Some(architecture.into()),
config_file: None,
tokenizer_file: None,
},
preprocessing: Vec::new(),
postprocessing: Vec::new(),
files: vec![model_file],
description: None,
metadata: HashMap::new(),
voices: None,
max_chunk_chars: None,
trim_trailing_samples: None,
}
}
pub fn model_graph(
model_id: impl Into<String>,
version: impl Into<String>,
stages: Vec<PipelineStage>,
files: Vec<String>,
) -> Self {
Self {
model_id: model_id.into(),
version: version.into(),
execution_template: ExecutionTemplate::ModelGraph {
stages,
config: HashMap::new(),
},
preprocessing: Vec::new(),
postprocessing: Vec::new(),
files,
description: None,
metadata: HashMap::new(),
voices: None,
max_chunk_chars: None,
trim_trailing_samples: None,
}
}
pub fn with_preprocessing(mut self, step: PreprocessingStep) -> Self {
self.preprocessing.push(step);
self
}
pub fn with_postprocessing(mut self, step: PostprocessingStep) -> Self {
self.postprocessing.push(step);
self
}
pub fn with_description(mut self, description: impl Into<String>) -> Self {
self.description = Some(description.into());
self
}
pub fn voice_config(&self) -> Option<&VoiceConfig> {
self.voices.as_ref()
}
pub fn get_voice(&self, voice_id: &str) -> Option<&VoiceInfo> {
self.voices
.as_ref()?
.catalog
.iter()
.find(|v| v.id == voice_id)
}
pub fn default_voice(&self) -> Option<&VoiceInfo> {
let config = self.voices.as_ref()?;
self.get_voice(&config.default)
}
pub fn list_voices(&self) -> Vec<&VoiceInfo> {
self.voices
.as_ref()
.map(|c| c.catalog.iter().collect())
.unwrap_or_default()
}
pub fn has_voices(&self) -> bool {
self.voices.is_some()
}
}
/// Serde default for `repetition_penalty` fields that reference this function by name.
fn default_repetition_penalty() -> f32 {
    const DEFAULT_REPETITION_PENALTY: f32 = 1.1;
    DEFAULT_REPETITION_PENALTY
}
/// Serde default for `ExecutionTemplate::Gguf::context_length`.
fn default_context_length() -> usize {
    const DEFAULT_CONTEXT_LENGTH: usize = 4096;
    DEFAULT_CONTEXT_LENGTH
}
/// Maps a task name to its short stage-kind label.
///
/// Matching is exact and case-sensitive; unrecognised task names yield `None`.
pub fn stage_kind_from_task(task: &str) -> Option<&'static str> {
    // (task alias, stage kind) pairs; every alias appears exactly once.
    const TASK_KINDS: &[(&str, &str)] = &[
        ("speech-recognition", "asr"),
        ("speech-to-text", "asr"),
        ("asr", "asr"),
        ("text-to-speech", "tts"),
        ("tts", "tts"),
        ("text-generation", "llm"),
        ("chat", "llm"),
        ("llm", "llm"),
        ("translation", "translate"),
        ("image-classification", "vision"),
        ("image-to-text", "vision"),
        ("vision", "vision"),
        ("embedding", "embed"),
        ("sentence-embedding", "embed"),
        ("audio-classification", "audio"),
        ("vad", "audio"),
    ];

    TASK_KINDS
        .iter()
        .find(|(alias, _)| *alias == task)
        .map(|&(_, kind)| kind)
}
/// Canonicalises an LLM backend hint.
///
/// `"llamacpp"` maps to itself, `"mistral"` and `"mistralrs"` both map to
/// `"mistralrs"`, and any other string (including the empty string) yields `None`.
/// Matching is exact and case-sensitive.
pub fn normalize_llm_backend_hint(hint: &str) -> Option<&'static str> {
    if hint == "llamacpp" {
        Some("llamacpp")
    } else if hint == "mistral" || hint == "mistralrs" {
        Some("mistralrs")
    } else {
        None
    }
}
/// Picks the backend label for an execution template.
///
/// * `Onnx` is always `"ort"`; the hint is ignored.
/// * `SafeTensors` and `Gguf` use the hint when `normalize_llm_backend_hint`
///   recognises it, otherwise `"candle"` and `"llamacpp"` respectively.
/// * `CoreMl`, `TfLite` and `ModelGraph` have no label and return `None`.
pub fn backend_label_from_template(
    template: &ExecutionTemplate,
    hint: Option<&str>,
) -> Option<&'static str> {
    // Label used when the hint is missing or not recognised.
    let fallback = match template {
        ExecutionTemplate::Onnx { .. } => return Some("ort"),
        ExecutionTemplate::SafeTensors { .. } => "candle",
        ExecutionTemplate::Gguf { .. } => "llamacpp",
        ExecutionTemplate::CoreMl { .. }
        | ExecutionTemplate::TfLite { .. }
        | ExecutionTemplate::ModelGraph { .. } => return None,
    };

    let hinted = hint.and_then(normalize_llm_backend_hint);
    Some(hinted.unwrap_or(fallback))
}
/// Lowercases a quantization label and rewrites `f16` / `f32` to `fp16` / `fp32`.
/// Any other label is returned lowercased and otherwise unchanged.
fn normalize_quantization_label(label: &str) -> String {
    let lowered = label.to_lowercase();
    if lowered == "f16" {
        String::from("fp16")
    } else if lowered == "f32" {
        String::from("fp32")
    } else {
        lowered
    }
}
/// Guesses a quantization label from a GGUF file name.
///
/// The name is lowercased and scanned for known quantization markers as substrings.
/// The first marker found, in table order, is passed through
/// `normalize_quantization_label` and returned. Returns `None` when no marker occurs.
fn infer_quantization_from_gguf_filename(filename: &str) -> Option<String> {
    // Order matters because matching is by substring: a marker that contains another
    // marker must come first. `bf16` sits ahead of `f16` so bfloat16 files are not
    // reported as `fp16`.
    const MARKERS: &[&str] = &[
        "q3_k_l", "q3_k_m", "q3_k_s", "q4_k_m", "q4_k_s", "q5_k_m", "q5_k_s", "q2_k", "q4_0",
        "q4_1", "q5_0", "q5_1", "q6_k", "q8_0", "bf16", "f16", "f32",
    ];

    let lower = filename.to_lowercase();
    MARKERS
        .iter()
        .find(|marker| lower.contains(**marker))
        .map(|marker| normalize_quantization_label(marker))
}
pub fn quantization_label_from_metadata(metadata: &ModelMetadata) -> Option<String> {
if let Some(declared) = metadata
.metadata
.get("quantization")
.and_then(|v| v.as_str())
.filter(|s| !s.is_empty())
{
return Some(normalize_quantization_label(declared));
}
if let ExecutionTemplate::Gguf { model_file, .. } = &metadata.execution_template {
return infer_quantization_from_gguf_filename(model_file);
}
None
}
/// Returns `"gpu"` or `"cpu"` for an execution template.
///
/// The answer is fixed at compile time:
/// * `CoreMl` is always `"gpu"`.
/// * `SafeTensors` is `"gpu"` only when the `candle-metal` feature is enabled.
/// * `Gguf` is `"gpu"` only on macOS builds with `llm-mistral-metal` or
///   `llm-llamacpp` enabled.
/// * `Onnx`, `TfLite` and `ModelGraph` are always `"cpu"`.
pub fn span_kind_from_template(template: &ExecutionTemplate) -> &'static str {
    // `cfg!` expands to a compile-time boolean, so each arm still resolves to a constant.
    let accelerated = match template {
        ExecutionTemplate::CoreMl { .. } => true,
        ExecutionTemplate::SafeTensors { .. } => cfg!(feature = "candle-metal"),
        ExecutionTemplate::Gguf { .. } => cfg!(all(
            any(feature = "llm-mistral-metal", feature = "llm-llamacpp"),
            target_os = "macos"
        )),
        ExecutionTemplate::Onnx { .. }
        | ExecutionTemplate::TfLite { .. }
        | ExecutionTemplate::ModelGraph { .. } => false,
    };

    if accelerated {
        "gpu"
    } else {
        "cpu"
    }
}
// Unit tests for serialization round-trips and the label helper functions above.
#[cfg(test)]
mod tests {
use super::*;
// An ONNX record built with the builder methods survives a JSON round-trip and
// carries the `"type": "Onnx"` tag in its pretty-printed form.
#[test]
fn test_onnx_serialization() {
let metadata = ModelMetadata::onnx("mnist", "1.0", "mnist.onnx")
.with_preprocessing(PreprocessingStep::Normalize {
mean: vec![0.1307],
std: vec![0.3081],
})
.with_postprocessing(PostprocessingStep::Argmax { dim: None });
let json = serde_json::to_string_pretty(&metadata).unwrap();
let parsed: ModelMetadata = serde_json::from_str(&json).unwrap();
assert_eq!(parsed.model_id, "mnist");
assert!(json.contains("\"type\": \"Onnx\""));
}
// An `Autoregressive` mode round-trips through JSON with `max_tokens` preserved.
#[test]
fn test_execution_modes() {
let autoregressive = ExecutionMode::Autoregressive {
max_tokens: 100,
start_token_id: 0,
end_token_id: 1,
repetition_penalty: 0.8,
};
let json = serde_json::to_string(&autoregressive).unwrap();
let parsed: ExecutionMode = serde_json::from_str(&json).unwrap();
match parsed {
ExecutionMode::Autoregressive { max_tokens, .. } => assert_eq!(max_tokens, 100),
_ => panic!("Expected autoregressive mode"),
}
}
// Backend labels for ONNX, SafeTensors and GGUF templates, with and without hints;
// an unrecognised hint (`mlx`) must fall back to the template's default label.
#[test]
fn backend_label_covers_canonical_runtimes() {
let onnx = ExecutionTemplate::Onnx {
model_file: "m.onnx".into(),
};
assert_eq!(backend_label_from_template(&onnx, None), Some("ort"));
let safe = ExecutionTemplate::SafeTensors {
model_file: "m.safetensors".into(),
architecture: None,
config_file: None,
tokenizer_file: None,
};
assert_eq!(backend_label_from_template(&safe, None), Some("candle"));
assert_eq!(
backend_label_from_template(&safe, Some("mlx")),
Some("candle"),
"deferred mlx hints must not claim an unavailable runtime"
);
let gguf = ExecutionTemplate::Gguf {
model_file: "m.gguf".into(),
chat_template: None,
context_length: 2048,
generation_params: None,
};
assert_eq!(
backend_label_from_template(&gguf, None),
Some("llamacpp"),
"unannotated GGUF bundles must default to the universal llama.cpp runtime"
);
assert_eq!(
backend_label_from_template(&gguf, Some("llamacpp")),
Some("llamacpp")
);
assert_eq!(
backend_label_from_template(&gguf, Some("mistral")),
Some("mistralrs")
);
assert_eq!(
backend_label_from_template(&gguf, Some("mistralrs")),
Some("mistralrs")
);
assert_eq!(
backend_label_from_template(&gguf, Some("mlx")),
Some("llamacpp"),
"deferred mlx hints must reflect the runtime that actually executes the model"
);
}
// Hint aliases collapse to their canonical name; unknown and empty hints give `None`.
#[test]
fn normalize_llm_backend_hint_canonicalises_aliases() {
assert_eq!(normalize_llm_backend_hint("mistral"), Some("mistralrs"));
assert_eq!(normalize_llm_backend_hint("mistralrs"), Some("mistralrs"));
assert_eq!(normalize_llm_backend_hint("llamacpp"), Some("llamacpp"));
assert_eq!(normalize_llm_backend_hint("mlx"), None);
assert_eq!(normalize_llm_backend_hint("unknown"), None);
assert_eq!(normalize_llm_backend_hint(""), None);
}
// CoreML, TFLite and ModelGraph templates have no backend label.
#[test]
fn backend_label_omits_unknown_runtimes() {
let coreml = ExecutionTemplate::CoreMl {
model_file: "m.mlmodel".into(),
};
assert!(backend_label_from_template(&coreml, None).is_none());
let tflite = ExecutionTemplate::TfLite {
model_file: "m.tflite".into(),
};
assert!(backend_label_from_template(&tflite, None).is_none());
let graph = ExecutionTemplate::ModelGraph {
stages: Vec::new(),
config: HashMap::new(),
};
assert!(backend_label_from_template(&graph, None).is_none());
}
// An explicit `quantization` metadata entry beats the marker in the GGUF file name
// and is returned lowercased.
#[test]
fn quantization_label_prefers_explicit_metadata_field() {
let mut metadata = ModelMetadata {
model_id: "qwen2.5-0.5b-instruct".into(),
version: "1".into(),
execution_template: ExecutionTemplate::Gguf {
model_file: "qwen2.5-0.5b-instruct-q4_k_m.gguf".into(),
chat_template: None,
context_length: 2048,
generation_params: None,
},
preprocessing: Vec::new(),
postprocessing: Vec::new(),
files: Vec::new(),
description: None,
metadata: HashMap::new(),
voices: None,
max_chunk_chars: None,
trim_trailing_samples: None,
};
metadata
.metadata
.insert("quantization".into(), serde_json::json!("Q8_0"));
assert_eq!(
quantization_label_from_metadata(&metadata).as_deref(),
Some("q8_0"),
"explicit metadata.quantization must override filename inference and be lowercased"
);
}
// Without explicit metadata the label is inferred from the GGUF file name, and
// `f16` is rewritten to `fp16`.
#[test]
fn quantization_label_falls_back_to_gguf_filename() {
let metadata = ModelMetadata {
model_id: "qwen2.5-0.5b".into(),
version: "1".into(),
execution_template: ExecutionTemplate::Gguf {
model_file: "qwen2.5-0.5b-instruct-q4_k_m.gguf".into(),
chat_template: None,
context_length: 2048,
generation_params: None,
},
preprocessing: Vec::new(),
postprocessing: Vec::new(),
files: Vec::new(),
description: None,
metadata: HashMap::new(),
voices: None,
max_chunk_chars: None,
trim_trailing_samples: None,
};
assert_eq!(
quantization_label_from_metadata(&metadata).as_deref(),
Some("q4_k_m")
);
let fp16 = ModelMetadata {
execution_template: ExecutionTemplate::Gguf {
model_file: "tinyllama-1.1b-chat-f16.gguf".into(),
chat_template: None,
context_length: 2048,
generation_params: None,
},
..metadata.clone()
};
assert_eq!(
quantization_label_from_metadata(&fp16).as_deref(),
Some("fp16"),
"GGUF `f16` must rewrite to canonical `fp16`"
);
}
// No label is produced for a non-GGUF template without metadata, nor for a GGUF
// file name that contains no known marker.
#[test]
fn quantization_label_omits_when_unknown() {
let onnx_no_meta = ModelMetadata {
model_id: "wav2vec2".into(),
version: "1".into(),
execution_template: ExecutionTemplate::Onnx {
model_file: "model.onnx".into(),
},
preprocessing: Vec::new(),
postprocessing: Vec::new(),
files: Vec::new(),
description: None,
metadata: HashMap::new(),
voices: None,
max_chunk_chars: None,
trim_trailing_samples: None,
};
assert!(
quantization_label_from_metadata(&onnx_no_meta).is_none(),
"ONNX with no metadata.quantization must omit the label"
);
let gguf_unknown = ModelMetadata {
execution_template: ExecutionTemplate::Gguf {
model_file: "custom-experimental.gguf".into(),
chat_template: None,
context_length: 2048,
generation_params: None,
},
..onnx_no_meta
};
assert!(
quantization_label_from_metadata(&gguf_unknown).is_none(),
"GGUF with no quantization marker in filename must omit"
);
}
// A JSON document with a `voices` section deserializes, reports voices as present,
// and resolves the default voice by its id.
#[test]
fn test_model_metadata_with_voices() {
let json = r#"{
"model_id": "test-tts",
"version": "1.0",
"execution_template": {"type": "Onnx", "model_file": "model.onnx"},
"voices": {
"format": "embedded",
"file": "voices.bin",
"loader": "binary_f32_256",
"default": "voice_1",
"catalog": [{"id": "voice_1", "name": "Voice 1", "index": 0}]
},
"files": ["model.onnx"]
}"#;
let metadata: ModelMetadata = serde_json::from_str(json).unwrap();
assert!(metadata.has_voices());
assert_eq!(metadata.default_voice().unwrap().id, "voice_1");
}
}