use log::debug;
use std::path::{Path, PathBuf};
use super::template::{
ModelMetadata, VoiceConfig, VoiceFormat, VoiceInfo, VoiceLoader, VoiceSelectionStrategy,
};
use super::types::ExecutorResult;
use crate::ir::Envelope;
use crate::runtime_adapter::AdapterError;
/// Embedding width used by `DefaultVoiceSource::default()` and for the zero-embedding
/// fallback returned when no voice file is available.
pub const DEFAULT_EMBEDDING_DIM: usize = 256;

/// One weighted voice inside a compound voice id such as `"af_sarah.4+af_nicole.6"`.
#[derive(Debug, Clone, PartialEq)]
struct VoiceComponent {
    /// Voice name: everything before the final `.` of the id segment.
    name: String,
    /// Blend weight, rescaled by `parse_compound_voice_id` so the weights sum to ~1.
    weight: f32,
}

/// Parses a compound voice id of the form `name.D+name.D[+...]`.
///
/// Each `+`-separated segment (surrounding whitespace trimmed) must be a non-empty name,
/// a `.`, and exactly one ASCII digit `D`, read as a weight of `D * 0.1`. The name may
/// itself contain dots; only the last `.` separates the weight.
///
/// If the weights do not already sum to 1 (within `f32::EPSILON`), they are rescaled to do
/// so. If every weight is zero, the voices share the weight equally.
///
/// Returns `None` for an id without `+`, or when any segment is malformed.
fn parse_compound_voice_id(voice_id: &str) -> Option<Vec<VoiceComponent>> {
    // A lone voice id is never compound, even if it happens to end in `.D`.
    if !voice_id.contains('+') {
        return None;
    }

    let parse_segment = |segment: &str| -> Option<VoiceComponent> {
        let segment = segment.trim();
        if segment.is_empty() {
            return None;
        }
        let (name, weight_digits) = segment.rsplit_once('.')?;
        if name.is_empty() {
            return None;
        }
        // Exactly one ASCII digit is accepted as the weight.
        let digit = match weight_digits.as_bytes() {
            [b] if b.is_ascii_digit() => u32::from(b - b'0'),
            _ => return None,
        };
        Some(VoiceComponent {
            name: name.to_string(),
            weight: digit as f32 * 0.1,
        })
    };

    let mut components = voice_id
        .split('+')
        .map(parse_segment)
        .collect::<Option<Vec<_>>>()?;

    let total: f32 = components.iter().map(|c| c.weight).sum();
    if total == 0.0 {
        let share = 1.0 / components.len() as f32;
        components.iter_mut().for_each(|c| c.weight = share);
    } else if total > 0.0 && (total - 1.0).abs() > f32::EPSILON {
        components.iter_mut().for_each(|c| c.weight /= total);
    }
    Some(components)
}
/// Backend that decodes voice embeddings from a voice file on disk.
///
/// Errors are plain strings; `TtsVoiceLoader` wraps them into
/// `AdapterError::RuntimeError` with voice-specific context.
pub trait VoiceEmbeddingSource: Send + Sync {
    /// Loads the embedding stored at position `index` in the voice file at `path`.
    fn load_by_index(&self, path: &Path, index: usize) -> Result<Vec<f32>, String>;

    /// Loads the embedding stored under `name` in the voice file at `path`.
    fn load_by_name(&self, path: &Path, name: &str) -> Result<Vec<f32>, String>;

    /// Loads the embedding for `name`, selected for a given `token_length`.
    ///
    /// The default implementation ignores `token_length` and delegates to
    /// [`load_by_name`](Self::load_by_name). Override it for voice files that store a
    /// distinct embedding per token length.
    fn load_by_name_with_token_length(
        &self,
        path: &Path,
        name: &str,
        token_length: usize,
    ) -> Result<Vec<f32>, String> {
        // Deliberately unused by the fallback; see the doc comment above.
        let _ = token_length;
        self.load_by_name(path, name)
    }
}
/// Voice-embedding source backed by the crate's `VoiceEmbeddingLoader`.
///
/// The configured `embedding_dim` is handed to `VoiceEmbeddingLoader::new` on every read.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DefaultVoiceSource {
    embedding_dim: usize,
}

impl DefaultVoiceSource {
    /// Creates a source that passes `embedding_dim` to the underlying loader.
    pub const fn new(embedding_dim: usize) -> Self {
        Self { embedding_dim }
    }

    /// Returns the embedding dimension this source was configured with.
    pub const fn embedding_dim(&self) -> usize {
        self.embedding_dim
    }
}
/// The default source uses [`DEFAULT_EMBEDDING_DIM`] as its embedding dimension.
impl Default for DefaultVoiceSource {
    fn default() -> Self {
        Self {
            embedding_dim: DEFAULT_EMBEDDING_DIM,
        }
    }
}
/// Each call builds a fresh `VoiceEmbeddingLoader` with this source's dimension and
/// stringifies any loader error.
impl VoiceEmbeddingSource for DefaultVoiceSource {
    /// Delegates to `VoiceEmbeddingLoader::load` with the given `index`.
    fn load_by_index(&self, path: &Path, index: usize) -> Result<Vec<f32>, String> {
        crate::tts::voice_embedding::VoiceEmbeddingLoader::new(self.embedding_dim)
            .load(path, index)
            .map_err(|err| err.to_string())
    }

    /// Delegates to `VoiceEmbeddingLoader::load_npz_by_name` without a token length.
    fn load_by_name(&self, path: &Path, name: &str) -> Result<Vec<f32>, String> {
        crate::tts::voice_embedding::VoiceEmbeddingLoader::new(self.embedding_dim)
            .load_npz_by_name(path, name, None)
            .map_err(|err| err.to_string())
    }

    /// Delegates to `VoiceEmbeddingLoader::load_npz_by_name` with `Some(token_length)`.
    fn load_by_name_with_token_length(
        &self,
        path: &Path,
        name: &str,
        token_length: usize,
    ) -> Result<Vec<f32>, String> {
        crate::tts::voice_embedding::VoiceEmbeddingLoader::new(self.embedding_dim)
            .load_npz_by_name(path, name, Some(token_length))
            .map_err(|err| err.to_string())
    }
}
/// Loads TTS voice embeddings for a model directory.
///
/// Voice file names taken from model metadata are resolved relative to `base_path`;
/// decoding is delegated to `source`, which defaults to [`DefaultVoiceSource`].
pub struct TtsVoiceLoader<S: VoiceEmbeddingSource = DefaultVoiceSource> {
    /// Directory that voice file names are joined onto.
    base_path: PathBuf,
    /// Backend that decodes embeddings from voice files.
    source: S,
}

impl TtsVoiceLoader<DefaultVoiceSource> {
    /// Creates a loader rooted at `base_path` that uses `DefaultVoiceSource::default()`.
    pub fn new(base_path: impl Into<PathBuf>) -> Self {
        let base_path: PathBuf = base_path.into();
        let source = DefaultVoiceSource::default();
        Self { base_path, source }
    }
}
impl<S: VoiceEmbeddingSource> TtsVoiceLoader<S> {
pub fn with_source(base_path: impl Into<PathBuf>, source: S) -> Self {
Self {
base_path: base_path.into(),
source,
}
}
pub fn load(&self, metadata: &ModelMetadata, input: &Envelope) -> ExecutorResult<Vec<f32>> {
self.load_for_token_count(metadata, input, None)
}
pub fn load_for_token_count(
&self,
metadata: &ModelMetadata,
input: &Envelope,
token_count: Option<usize>,
) -> ExecutorResult<Vec<f32>> {
let voice_path = match self.resolve_voice_path(metadata)? {
Some(path) => path,
None => {
debug!(target: "xybrid_core", "No voice file found, using zero embedding");
return Ok(vec![0.0f32; DEFAULT_EMBEDDING_DIM]);
}
};
if !voice_path.exists() {
debug!(target: "xybrid_core", "Voice file not found: {:?}, using zero embedding", voice_path);
return Ok(vec![0.0f32; DEFAULT_EMBEDDING_DIM]);
}
let voice_id = input.metadata.get("voice_id");
if let Some(vid) = voice_id {
if let Some(components) = parse_compound_voice_id(vid) {
debug!(
target: "xybrid_core",
"Compound voice ID detected: {} ({} components)",
vid,
components.len()
);
return self.load_blended_voice(&voice_path, &components, metadata, token_count);
}
}
if let Some(voice_config) = &metadata.voices {
self.load_with_config_and_token_length(
&voice_path,
voice_config,
metadata,
voice_id,
token_count,
)
} else {
self.load_legacy(&voice_path, voice_id)
}
}
fn resolve_voice_path(&self, metadata: &ModelMetadata) -> ExecutorResult<Option<PathBuf>> {
if let Some(voice_config) = &metadata.voices {
match &voice_config.format {
VoiceFormat::Embedded { file, .. } => Ok(Some(self.base_path.join(file))),
VoiceFormat::PrecomputedCodes { .. } => {
Err(AdapterError::InvalidInput(
"PrecomputedCodes voice format uses load_reference_codes(), not the embedding path".to_string(),
))
}
VoiceFormat::PerModel { .. } | VoiceFormat::Cloning { .. } => {
Err(AdapterError::InvalidInput(
"Only embedded voice format is currently supported".to_string(),
))
}
}
} else {
let voices_bin = self.base_path.join("voices.bin");
let voices_npz = self.base_path.join("voices.npz");
if voices_bin.exists() {
Ok(Some(voices_bin))
} else if voices_npz.exists() {
Ok(Some(voices_npz))
} else {
Ok(None)
}
}
}
fn load_with_config(
&self,
voice_path: &Path,
voice_config: &VoiceConfig,
metadata: &ModelMetadata,
voice_id: Option<&String>,
) -> ExecutorResult<Vec<f32>> {
self.load_with_config_and_token_length(voice_path, voice_config, metadata, voice_id, None)
}
fn load_with_config_and_token_length(
&self,
voice_path: &Path,
voice_config: &VoiceConfig,
metadata: &ModelMetadata,
voice_id: Option<&String>,
token_count: Option<usize>,
) -> ExecutorResult<Vec<f32>> {
let voice_info = self.resolve_voice_info(voice_config, metadata, voice_id)?;
debug!(
target: "xybrid_core",
"Loading voice: {} (index: {:?}, strategy: {:?}, token_count: {:?})",
voice_info.id,
voice_info.index,
voice_config.selection_strategy,
token_count
);
match voice_config.selection_strategy {
VoiceSelectionStrategy::TokenLength => {
let token_len = token_count
.map(|tc| tc.saturating_sub(2).min(509))
.unwrap_or(100);
debug!(
target: "xybrid_core",
"TokenLength strategy: loading '{}' at token_length={}",
voice_info.id,
token_len
);
self.source
.load_by_name_with_token_length(voice_path, &voice_info.id, token_len)
.map_err(|e| {
AdapterError::RuntimeError(format!(
"Failed to load voice '{}' with token length: {}",
voice_info.id, e
))
})
}
VoiceSelectionStrategy::FixedIndex => {
let is_npz = matches!(
&voice_config.format,
VoiceFormat::Embedded {
loader: VoiceLoader::NumpyNpz,
..
}
);
if is_npz {
self.source
.load_by_name(voice_path, &voice_info.id)
.map_err(|e| {
AdapterError::RuntimeError(format!(
"Failed to load voice '{}' by name: {}",
voice_info.id, e
))
})
} else if let Some(index) = voice_info.index {
self.source.load_by_index(voice_path, index).map_err(|e| {
AdapterError::RuntimeError(format!(
"Failed to load voice '{}' (index {}): {}",
voice_info.id, index, e
))
})
} else {
self.source
.load_by_name(voice_path, &voice_info.id)
.map_err(|e| {
AdapterError::RuntimeError(format!(
"Failed to load voice '{}' by name: {}",
voice_info.id, e
))
})
}
}
}
}
fn resolve_voice_info<'a>(
&self,
voice_config: &'a VoiceConfig,
metadata: &'a ModelMetadata,
voice_id: Option<&String>,
) -> ExecutorResult<&'a VoiceInfo> {
if let Some(vid) = voice_id {
metadata.get_voice(vid).ok_or_else(|| {
let available: Vec<_> =
voice_config.catalog.iter().map(|v| v.id.as_str()).collect();
AdapterError::InvalidInput(format!(
"Voice '{}' not found. Available voices: {:?}",
vid, available
))
})
} else {
metadata.default_voice().ok_or_else(|| {
AdapterError::RuntimeError(format!(
"Default voice '{}' not found in catalog",
voice_config.default
))
})
}
}
fn load_legacy(
&self,
voice_path: &Path,
voice_id: Option<&String>,
) -> ExecutorResult<Vec<f32>> {
if let Some(vid) = voice_id {
if let Ok(index) = vid.parse::<usize>() {
debug!(target: "xybrid_core", "Loading voice by index: {}", index);
self.source.load_by_index(voice_path, index)
} else {
debug!(target: "xybrid_core", "Loading voice by name: {}", vid);
self.source.load_by_name(voice_path, vid)
}
} else {
debug!(target: "xybrid_core", "Loading default voice (index 0)");
self.source.load_by_index(voice_path, 0)
}
.map_err(|e| AdapterError::RuntimeError(format!("Failed to load voice embedding: {}", e)))
}
pub fn load_reference_codes(
&self,
metadata: &ModelMetadata,
voice_id: &str,
) -> ExecutorResult<(Vec<i32>, String)> {
let voice_config = metadata
.voices
.as_ref()
.ok_or_else(|| AdapterError::InvalidInput("No voice config in metadata".to_string()))?;
match &voice_config.format {
VoiceFormat::PrecomputedCodes {
codes_dir,
codes_pattern,
transcript_dir,
transcript_pattern,
} => {
let codes_file = codes_pattern.replace("{voice_id}", voice_id);
let codes_path = self.base_path.join(codes_dir).join(&codes_file);
let transcript_file = transcript_pattern.replace("{voice_id}", voice_id);
let transcript_path = self.base_path.join(transcript_dir).join(&transcript_file);
if !codes_path.exists() {
return Err(AdapterError::RuntimeError(format!(
"Voice codes file not found: {:?}",
codes_path
)));
}
if !transcript_path.exists() {
return Err(AdapterError::RuntimeError(format!(
"Voice transcript file not found: {:?}",
transcript_path
)));
}
let codes_data = std::fs::read(&codes_path).map_err(|e| {
AdapterError::RuntimeError(format!(
"Failed to read voice codes {:?}: {}",
codes_path, e
))
})?;
let codes = read_codes_binary(&codes_data)?;
let transcript = std::fs::read_to_string(&transcript_path).map_err(|e| {
AdapterError::RuntimeError(format!(
"Failed to read voice transcript {:?}: {}",
transcript_path, e
))
})?;
debug!(
target: "xybrid_core",
"Loaded reference codes for voice '{}': {} codes, transcript len={}",
voice_id,
codes.len(),
transcript.len()
);
Ok((codes, transcript.trim().to_string()))
}
_ => Err(AdapterError::InvalidInput(
"load_reference_codes requires PrecomputedCodes voice format".to_string(),
)),
}
}
fn load_blended_voice(
&self,
voice_path: &Path,
components: &[VoiceComponent],
metadata: &ModelMetadata,
token_count: Option<usize>,
) -> ExecutorResult<Vec<f32>> {
let mut blended: Option<Vec<f32>> = None;
for component in components {
let voice_id_str = component.name.clone();
let embedding = if let Some(voice_config) = &metadata.voices {
self.load_with_config_and_token_length(
voice_path,
voice_config,
metadata,
Some(&voice_id_str),
token_count,
)?
} else {
self.load_legacy(voice_path, Some(&voice_id_str))?
};
match &mut blended {
None => {
blended = Some(embedding.iter().map(|&v| v * component.weight).collect());
}
Some(acc) => {
if acc.len() != embedding.len() {
return Err(AdapterError::RuntimeError(format!(
"Voice embedding dimension mismatch: expected {}, got {} for '{}'",
acc.len(),
embedding.len(),
component.name
)));
}
for (a, &e) in acc.iter_mut().zip(embedding.iter()) {
*a += e * component.weight;
}
}
}
}
blended
.ok_or_else(|| AdapterError::RuntimeError("No voice components to blend".to_string()))
}
}
/// Decodes a precomputed voice-codes file.
///
/// Layout: a little-endian `u32` count, followed by `count` little-endian `i32` codes.
/// Bytes after the declared codes are ignored.
///
/// # Errors
///
/// Returns `AdapterError::InvalidInput` when the header is missing, when the payload is
/// shorter than the declared count, or when the declared byte size does not fit in `usize`.
/// Without that last check, `4 + count * 4` could wrap on 32-bit targets and bypass the
/// length check.
fn read_codes_binary(data: &[u8]) -> ExecutorResult<Vec<i32>> {
    const HEADER_LEN: usize = 4;
    const CODE_LEN: usize = 4;

    if data.len() < HEADER_LEN {
        return Err(AdapterError::InvalidInput(
            "Voice codes file too small to contain count header".to_string(),
        ));
    }
    let raw_count = u32::from_le_bytes([data[0], data[1], data[2], data[3]]);

    let too_large = || {
        AdapterError::InvalidInput(format!(
            "Voice codes file declares {} codes, which exceeds the addressable size",
            raw_count
        ))
    };
    let count = usize::try_from(raw_count).map_err(|_| too_large())?;
    let expected = count
        .checked_mul(CODE_LEN)
        .and_then(|payload| payload.checked_add(HEADER_LEN))
        .ok_or_else(too_large)?;

    if data.len() < expected {
        return Err(AdapterError::InvalidInput(format!(
            "Voice codes file truncated: expected {} bytes for {} codes, got {}",
            expected,
            count,
            data.len()
        )));
    }

    let codes: Vec<i32> = data[HEADER_LEN..expected]
        .chunks_exact(CODE_LEN)
        .map(|chunk| i32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
        .collect();
    Ok(codes)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// In-memory voice source keyed by index and by name.
    ///
    /// Token length is ignored: `load_by_name_with_token_length` delegates to `load_by_name`.
    struct MockVoiceSource {
        voices_by_index: HashMap<usize, Vec<f32>>,
        voices_by_name: HashMap<String, Vec<f32>>,
    }

    impl MockVoiceSource {
        fn new() -> Self {
            Self {
                voices_by_index: HashMap::new(),
                voices_by_name: HashMap::new(),
            }
        }

        fn with_voice_at_index(mut self, index: usize, embedding: Vec<f32>) -> Self {
            self.voices_by_index.insert(index, embedding);
            self
        }

        fn with_voice_by_name(mut self, name: &str, embedding: Vec<f32>) -> Self {
            self.voices_by_name.insert(name.to_string(), embedding);
            self
        }
    }

    impl VoiceEmbeddingSource for MockVoiceSource {
        fn load_by_index(&self, _path: &Path, index: usize) -> Result<Vec<f32>, String> {
            self.voices_by_index
                .get(&index)
                .cloned()
                .ok_or_else(|| format!("Voice at index {} not found", index))
        }

        fn load_by_name(&self, _path: &Path, name: &str) -> Result<Vec<f32>, String> {
            self.voices_by_name
                .get(name)
                .cloned()
                .ok_or_else(|| format!("Voice '{}' not found", name))
        }

        fn load_by_name_with_token_length(
            &self,
            path: &Path,
            name: &str,
            _token_length: usize,
        ) -> Result<Vec<f32>, String> {
            self.load_by_name(path, name)
        }
    }

    /// Metadata with an embedded `voices.bin` and a two-voice catalog (`voice_a` at index 0,
    /// `voice_b` at index 1) using the default selection strategy.
    fn create_test_metadata_with_voices() -> ModelMetadata {
        use super::super::template::{
            VoiceConfig, VoiceFormat, VoiceInfo, VoiceLoader, VoiceSelectionStrategy,
        };
        let mut metadata = ModelMetadata::onnx("test-tts", "1.0", "model.onnx");
        metadata.voices = Some(VoiceConfig {
            format: VoiceFormat::Embedded {
                file: "voices.bin".to_string(),
                loader: VoiceLoader::BinaryF32_256,
            },
            default: "voice_a".to_string(),
            catalog: vec![
                VoiceInfo {
                    id: "voice_a".to_string(),
                    name: "Voice A".to_string(),
                    gender: None,
                    language: None,
                    style: None,
                    index: Some(0),
                    preview_url: None,
                },
                VoiceInfo {
                    id: "voice_b".to_string(),
                    name: "Voice B".to_string(),
                    gender: None,
                    language: None,
                    style: None,
                    index: Some(1),
                    preview_url: None,
                },
            ],
            selection_strategy: VoiceSelectionStrategy::default(),
        });
        metadata
    }

    /// Same catalog as `create_test_metadata_with_voices`, with an explicit strategy.
    fn create_test_metadata_with_strategy(strategy: VoiceSelectionStrategy) -> ModelMetadata {
        let mut metadata = create_test_metadata_with_voices();
        if let Some(voice_config) = metadata.voices.as_mut() {
            voice_config.selection_strategy = strategy;
        }
        metadata
    }

    fn create_test_envelope() -> Envelope {
        use crate::ir::EnvelopeKind;
        Envelope::new(EnvelopeKind::Text("Hello world".to_string()))
    }

    fn create_test_envelope_with_voice(voice_id: &str) -> Envelope {
        use crate::ir::EnvelopeKind;
        let mut metadata = HashMap::new();
        metadata.insert("voice_id".to_string(), voice_id.to_string());
        Envelope::with_metadata(EnvelopeKind::Text("Hello world".to_string()), metadata)
    }

    #[test]
    fn test_resolve_voice_path_with_config() {
        let source = MockVoiceSource::new();
        let loader = TtsVoiceLoader::with_source("/models/tts", source);
        let metadata = create_test_metadata_with_voices();
        let path = loader.resolve_voice_path(&metadata).unwrap();
        assert!(path.is_some());
        assert!(path.unwrap().ends_with("voices.bin"));
    }

    #[test]
    fn test_resolve_voice_path_no_config_no_files() {
        let source = MockVoiceSource::new();
        let loader = TtsVoiceLoader::with_source("/nonexistent", source);
        let metadata = ModelMetadata::onnx("test", "1.0", "model.onnx");
        let path = loader.resolve_voice_path(&metadata).unwrap();
        assert!(path.is_none());
    }

    #[test]
    fn test_resolve_voice_info_default() {
        let source = MockVoiceSource::new();
        let loader = TtsVoiceLoader::with_source("/models", source);
        let metadata = create_test_metadata_with_voices();
        let voice_config = metadata.voices.as_ref().unwrap();
        let info = loader
            .resolve_voice_info(voice_config, &metadata, None)
            .unwrap();
        assert_eq!(info.id, "voice_a");
    }

    #[test]
    fn test_resolve_voice_info_by_id() {
        let source = MockVoiceSource::new();
        let loader = TtsVoiceLoader::with_source("/models", source);
        let metadata = create_test_metadata_with_voices();
        let voice_config = metadata.voices.as_ref().unwrap();
        let voice_id = "voice_b".to_string();
        let info = loader
            .resolve_voice_info(voice_config, &metadata, Some(&voice_id))
            .unwrap();
        assert_eq!(info.id, "voice_b");
        assert_eq!(info.index, Some(1));
    }

    #[test]
    fn test_resolve_voice_info_not_found() {
        let source = MockVoiceSource::new();
        let loader = TtsVoiceLoader::with_source("/models", source);
        let metadata = create_test_metadata_with_voices();
        let voice_config = metadata.voices.as_ref().unwrap();
        let voice_id = "nonexistent".to_string();
        let result = loader.resolve_voice_info(voice_config, &metadata, Some(&voice_id));
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("not found"));
        // The error lists the catalog so the caller can pick a valid voice.
        assert!(err.contains("voice_a"));
    }

    #[test]
    fn test_fixed_index_strategy_loads_by_catalog_index() {
        let embedding = vec![0.5, 0.25];
        let source = MockVoiceSource::new().with_voice_at_index(1, embedding.clone());
        let loader = TtsVoiceLoader::with_source("/models", source);
        let metadata = create_test_metadata_with_strategy(VoiceSelectionStrategy::FixedIndex);
        let voice_config = metadata.voices.as_ref().unwrap();
        let voice_id = "voice_b".to_string();
        let result = loader
            .load_with_config_and_token_length(
                Path::new("/models/voices.bin"),
                voice_config,
                &metadata,
                Some(&voice_id),
                None,
            )
            .unwrap();
        assert_eq!(result, embedding);
    }

    #[test]
    fn test_token_length_strategy_loads_default_voice_by_name() {
        let embedding = vec![0.1, 0.2, 0.3];
        let source = MockVoiceSource::new().with_voice_by_name("voice_a", embedding.clone());
        let loader = TtsVoiceLoader::with_source("/models", source);
        let metadata = create_test_metadata_with_strategy(VoiceSelectionStrategy::TokenLength);
        let voice_config = metadata.voices.as_ref().unwrap();
        let result = loader
            .load_with_config_and_token_length(
                Path::new("/models/voices.bin"),
                voice_config,
                &metadata,
                None,
                Some(12),
            )
            .unwrap();
        assert_eq!(result, embedding);
    }

    #[test]
    fn test_load_legacy_by_index() {
        let embedding = vec![1.0, 2.0, 3.0];
        let source = MockVoiceSource::new().with_voice_at_index(5, embedding.clone());
        let loader = TtsVoiceLoader::with_source("/models", source);
        let voice_id = "5".to_string();
        let result = loader
            .load_legacy(Path::new("/models/voices.bin"), Some(&voice_id))
            .unwrap();
        assert_eq!(result, embedding);
    }

    #[test]
    fn test_load_legacy_by_name() {
        let embedding = vec![4.0, 5.0, 6.0];
        let source = MockVoiceSource::new().with_voice_by_name("bella", embedding.clone());
        let loader = TtsVoiceLoader::with_source("/models", source);
        let voice_id = "bella".to_string();
        let result = loader
            .load_legacy(Path::new("/models/voices.npz"), Some(&voice_id))
            .unwrap();
        assert_eq!(result, embedding);
    }

    #[test]
    fn test_load_legacy_default_index_0() {
        let embedding = vec![7.0, 8.0, 9.0];
        let source = MockVoiceSource::new().with_voice_at_index(0, embedding.clone());
        let loader = TtsVoiceLoader::with_source("/models", source);
        let result = loader
            .load_legacy(Path::new("/models/voices.bin"), None)
            .unwrap();
        assert_eq!(result, embedding);
    }

    #[test]
    fn test_load_returns_zero_embedding_when_no_voice_file() {
        let source = MockVoiceSource::new();
        let loader = TtsVoiceLoader::with_source("/nonexistent/path", source);
        let metadata = ModelMetadata::onnx("test", "1.0", "model.onnx");
        let input = create_test_envelope();
        let result = loader.load(&metadata, &input).unwrap();
        assert_eq!(result.len(), DEFAULT_EMBEDDING_DIM);
        assert!(result.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_parse_compound_two_voices() {
        let result = parse_compound_voice_id("af_sarah.4+af_nicole.6").unwrap();
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].name, "af_sarah");
        assert_eq!(result[1].name, "af_nicole");
        assert!((result[0].weight - 0.4).abs() < f32::EPSILON);
        assert!((result[1].weight - 0.6).abs() < f32::EPSILON);
    }

    #[test]
    fn test_parse_compound_three_voices() {
        let result = parse_compound_voice_id("af_heart.3+af_sarah.3+am_adam.4").unwrap();
        assert_eq!(result.len(), 3);
        assert_eq!(result[0].name, "af_heart");
        assert_eq!(result[1].name, "af_sarah");
        assert_eq!(result[2].name, "am_adam");
        assert!((result[0].weight - 0.3).abs() < f32::EPSILON);
        assert!((result[1].weight - 0.3).abs() < f32::EPSILON);
        assert!((result[2].weight - 0.4).abs() < f32::EPSILON);
    }

    #[test]
    fn test_parse_compound_weight_normalization() {
        let result = parse_compound_voice_id("af_sarah.2+af_nicole.3").unwrap();
        assert_eq!(result.len(), 2);
        assert!((result[0].weight - 0.4).abs() < 1e-6);
        assert!((result[1].weight - 0.6).abs() < 1e-6);
    }

    #[test]
    fn test_parse_compound_zero_weights_distribute_equally() {
        let result = parse_compound_voice_id("af_sarah.0+af_nicole.0").unwrap();
        assert_eq!(result.len(), 2);
        assert!((result[0].weight - 0.5).abs() < f32::EPSILON);
        assert!((result[1].weight - 0.5).abs() < f32::EPSILON);
    }

    #[test]
    fn test_parse_compound_trims_whitespace_around_parts() {
        let result = parse_compound_voice_id(" af_sarah.4 + af_nicole.6 ").unwrap();
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].name, "af_sarah");
        assert_eq!(result[1].name, "af_nicole");
    }

    #[test]
    fn test_parse_single_voice_returns_none() {
        assert!(parse_compound_voice_id("af_sarah").is_none());
    }

    #[test]
    fn test_parse_malformed_no_weight_returns_none() {
        assert!(parse_compound_voice_id("af_sarah+af_nicole").is_none());
    }

    #[test]
    fn test_parse_malformed_multi_digit_weight_returns_none() {
        assert!(parse_compound_voice_id("af_sarah.40+af_nicole.60").is_none());
    }

    #[test]
    fn test_parse_malformed_empty_name_returns_none() {
        assert!(parse_compound_voice_id(".4+af_nicole.6").is_none());
    }

    #[test]
    fn test_parse_malformed_empty_part_returns_none() {
        assert!(parse_compound_voice_id("af_sarah.4++af_nicole.6").is_none());
    }

    #[test]
    fn test_parse_malformed_non_digit_weight_returns_none() {
        assert!(parse_compound_voice_id("af_sarah.x+af_nicole.6").is_none());
    }

    #[test]
    fn test_load_blended_voice_two_voices() {
        let source = MockVoiceSource::new()
            .with_voice_by_name("voice_a", vec![1.0, 0.0, 0.0])
            .with_voice_by_name("voice_b", vec![0.0, 1.0, 0.0]);
        let loader = TtsVoiceLoader::with_source("/models", source);
        let components = vec![
            VoiceComponent {
                name: "voice_a".to_string(),
                weight: 0.4,
            },
            VoiceComponent {
                name: "voice_b".to_string(),
                weight: 0.6,
            },
        ];
        let metadata = ModelMetadata::onnx("test", "1.0", "model.onnx");
        let result = loader
            .load_blended_voice(
                Path::new("/models/voices.npz"),
                &components,
                &metadata,
                None,
            )
            .unwrap();
        assert_eq!(result.len(), 3);
        assert!((result[0] - 0.4).abs() < f32::EPSILON);
        assert!((result[1] - 0.6).abs() < f32::EPSILON);
        assert!((result[2] - 0.0).abs() < f32::EPSILON);
    }

    #[test]
    fn test_load_blended_voice_three_voices() {
        let source = MockVoiceSource::new()
            .with_voice_by_name("a", vec![1.0, 0.0, 0.0])
            .with_voice_by_name("b", vec![0.0, 1.0, 0.0])
            .with_voice_by_name("c", vec![0.0, 0.0, 1.0]);
        let loader = TtsVoiceLoader::with_source("/models", source);
        let components = vec![
            VoiceComponent {
                name: "a".to_string(),
                weight: 0.2,
            },
            VoiceComponent {
                name: "b".to_string(),
                weight: 0.3,
            },
            VoiceComponent {
                name: "c".to_string(),
                weight: 0.5,
            },
        ];
        let metadata = ModelMetadata::onnx("test", "1.0", "model.onnx");
        let result = loader
            .load_blended_voice(
                Path::new("/models/voices.npz"),
                &components,
                &metadata,
                None,
            )
            .unwrap();
        assert_eq!(result.len(), 3);
        assert!((result[0] - 0.2).abs() < f32::EPSILON);
        assert!((result[1] - 0.3).abs() < f32::EPSILON);
        assert!((result[2] - 0.5).abs() < f32::EPSILON);
    }

    #[test]
    fn test_load_blended_voice_dimension_mismatch() {
        // "a" is 2-dimensional while "b" is 3-dimensional, so blending must fail.
        let source = MockVoiceSource::new()
            .with_voice_by_name("a", vec![1.0, 0.0])
            .with_voice_by_name("b", vec![0.0, 1.0, 0.0]);
        let loader = TtsVoiceLoader::with_source("/models", source);
        let components = vec![
            VoiceComponent {
                name: "a".to_string(),
                weight: 0.5,
            },
            VoiceComponent {
                name: "b".to_string(),
                weight: 0.5,
            },
        ];
        let metadata = ModelMetadata::onnx("test", "1.0", "model.onnx");
        let result = loader.load_blended_voice(
            Path::new("/models/voices.npz"),
            &components,
            &metadata,
            None,
        );
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("dimension mismatch"));
    }

    #[test]
    fn test_read_codes_binary_decodes_little_endian() {
        let data = [2, 0, 0, 0, 1, 0, 0, 0, 0xFF, 0xFF, 0xFF, 0xFF];
        assert_eq!(read_codes_binary(&data).unwrap(), vec![1, -1]);
    }

    #[test]
    fn test_read_codes_binary_ignores_trailing_bytes() {
        let data = [1, 0, 0, 0, 7, 0, 0, 0, 9];
        assert_eq!(read_codes_binary(&data).unwrap(), vec![7]);
    }

    #[test]
    fn test_read_codes_binary_rejects_short_header() {
        assert!(read_codes_binary(&[1, 0]).is_err());
    }

    #[test]
    fn test_read_codes_binary_rejects_truncated_payload() {
        let data = [3, 0, 0, 0, 1, 0, 0, 0];
        assert!(read_codes_binary(&data).is_err());
    }

    #[test]
    fn test_read_codes_binary_rejects_huge_count() {
        assert!(read_codes_binary(&[0xFF, 0xFF, 0xFF, 0xFF]).is_err());
    }
}