1use std::collections::{BTreeMap, HashMap};
8use std::path::{Path, PathBuf};
9use std::sync::Arc;
10
11use candle_nn::VarBuilder;
12use tracing::info;
13
14use crate::device::DeviceSelection;
15use crate::error::TtsError;
16use crate::models::ModelType;
17
/// Normalizes an asset path into the form used for bundle keys and asset names.
///
/// Backslashes are turned into forward slashes, then every leading `./` is
/// stripped, and finally every leading `/` is stripped. The rest of the path
/// is left untouched.
fn normalize_asset_path(path: impl AsRef<str>) -> String {
    let forward_slashed = path.as_ref().replace('\\', "/");
    let without_dot_prefix = forward_slashed.trim_start_matches("./");
    let relative = without_dot_prefix.trim_start_matches('/');
    relative.to_owned()
}
25
/// A single model file, referenced either by filesystem path or held in memory.
#[derive(Debug, Clone)]
pub enum ModelAsset {
    /// A file located on disk at the given path.
    Path(PathBuf),
    /// In-memory file contents together with the name they are known by.
    Bytes { name: String, data: Arc<[u8]> },
}
32
33impl ModelAsset {
34 pub fn from_path(path: impl Into<PathBuf>) -> Self {
35 Self::Path(path.into())
36 }
37
38 pub fn from_bytes(name: impl Into<String>, bytes: impl Into<Vec<u8>>) -> Self {
39 Self::Bytes {
40 name: normalize_asset_path(name.into()),
41 data: Arc::from(bytes.into()),
42 }
43 }
44
45 pub fn as_path(&self) -> Option<&Path> {
46 match self {
47 Self::Path(path) => Some(path),
48 Self::Bytes { .. } => None,
49 }
50 }
51
52 pub fn file_name(&self) -> Option<&str> {
53 match self {
54 Self::Path(path) => path.file_name().and_then(|name| name.to_str()),
55 Self::Bytes { name, .. } => {
56 Path::new(name).file_name().and_then(|value| value.to_str())
57 }
58 }
59 }
60
61 pub fn extension(&self) -> Option<&str> {
62 match self {
63 Self::Path(path) => path.extension().and_then(|ext| ext.to_str()),
64 Self::Bytes { name, .. } => Path::new(name).extension().and_then(|ext| ext.to_str()),
65 }
66 }
67
68 pub fn display_name(&self) -> String {
69 match self {
70 Self::Path(path) => path.display().to_string(),
71 Self::Bytes { name, .. } => name.clone(),
72 }
73 }
74
75 pub fn read_bytes(&self) -> Result<Arc<[u8]>, TtsError> {
76 match self {
77 Self::Path(path) => std::fs::read(path).map(Arc::from).map_err(TtsError::from),
78 Self::Bytes { data, .. } => Ok(data.clone()),
79 }
80 }
81}
82
/// A flat collection of model files, backed either by a directory on disk or
/// by an in-memory map.
#[derive(Debug, Clone)]
pub enum ModelAssetDir {
    /// A directory on disk.
    Path(PathBuf),
    /// In-memory entries keyed by file name.
    Bytes(BTreeMap<String, Arc<[u8]>>),
}
89
90impl ModelAssetDir {
91 pub fn from_path(path: impl Into<PathBuf>) -> Self {
92 Self::Path(path.into())
93 }
94
95 pub fn from_bytes(entries: BTreeMap<String, Arc<[u8]>>) -> Self {
96 Self::Bytes(entries)
97 }
98
99 pub fn load_file(&self, name: &str) -> Result<ModelAsset, TtsError> {
100 match self {
101 Self::Path(path) => {
102 let full_path = path.join(name);
103 if !full_path.exists() {
104 return Err(TtsError::FileMissing(format!(
105 "{} in {}",
106 name,
107 path.display()
108 )));
109 }
110 Ok(ModelAsset::from_path(full_path))
111 }
112 Self::Bytes(entries) => entries
113 .get(name)
114 .cloned()
115 .map(|data| ModelAsset::Bytes {
116 name: name.to_string(),
117 data,
118 })
119 .ok_or_else(|| TtsError::FileMissing(name.to_string())),
120 }
121 }
122
123 pub fn file_names(&self) -> Result<Vec<String>, TtsError> {
124 match self {
125 Self::Path(path) => {
126 let mut names = Vec::new();
127 for entry in std::fs::read_dir(path)? {
128 let entry = entry?;
129 let Some(name) = entry.file_name().to_str().map(str::to_string) else {
130 continue;
131 };
132 names.push(name);
133 }
134 names.sort();
135 Ok(names)
136 }
137 Self::Bytes(entries) => Ok(entries.keys().cloned().collect()),
138 }
139 }
140}
141
/// An in-memory set of model files keyed by normalized relative path.
#[derive(Debug, Clone, Default)]
pub struct ModelAssetBundle {
    // Keys are produced by `normalize_asset_path` when entries are inserted.
    entries: BTreeMap<String, Arc<[u8]>>,
}
147
148impl ModelAssetBundle {
149 pub fn new() -> Self {
150 Self::default()
151 }
152
153 pub fn insert_bytes(
154 &mut self,
155 relative_path: impl Into<String>,
156 bytes: impl Into<Vec<u8>>,
157 ) -> &mut Self {
158 let relative_path = normalize_asset_path(relative_path.into());
159 self.entries.insert(relative_path, Arc::from(bytes.into()));
160 self
161 }
162
163 pub fn with_bytes(
164 mut self,
165 relative_path: impl Into<String>,
166 bytes: impl Into<Vec<u8>>,
167 ) -> Self {
168 self.insert_bytes(relative_path, bytes);
169 self
170 }
171
172 pub fn is_empty(&self) -> bool {
173 self.entries.is_empty()
174 }
175
176 fn get(&self, relative_path: &str) -> Option<ModelAsset> {
177 let relative_path = normalize_asset_path(relative_path);
178 self.entries
179 .get(&relative_path)
180 .cloned()
181 .map(|data| ModelAsset::Bytes {
182 name: relative_path,
183 data,
184 })
185 }
186
187 fn collect_directory(&self, prefix: &str) -> Option<ModelAssetDir> {
188 let prefix = normalize_asset_path(prefix);
189 let prefix = if prefix.ends_with('/') {
190 prefix
191 } else {
192 format!("{prefix}/")
193 };
194
195 let mut entries = BTreeMap::new();
196 for (path, data) in &self.entries {
197 let Some(rest) = path.strip_prefix(&prefix) else {
198 continue;
199 };
200 if rest.is_empty() || rest.contains('/') {
201 continue;
202 }
203 entries.insert(rest.to_string(), data.clone());
204 }
205
206 if entries.is_empty() {
207 None
208 } else {
209 Some(ModelAssetDir::from_bytes(entries))
210 }
211 }
212
213 fn discover_sharded_weights(&self, prefix: &str) -> Vec<ModelAsset> {
214 let prefix = normalize_asset_path(prefix);
215 let prefix = if prefix.is_empty() {
216 String::new()
217 } else if prefix.ends_with('/') {
218 prefix
219 } else {
220 format!("{prefix}/")
221 };
222
223 let mut shards = self
224 .entries
225 .iter()
226 .filter_map(|(path, data)| {
227 let rest = if prefix.is_empty() {
228 path.as_str()
229 } else {
230 path.strip_prefix(&prefix)?
231 };
232 if rest.contains('/')
233 || !rest.starts_with("model-")
234 || !rest.ends_with(".safetensors")
235 {
236 return None;
237 }
238 Some(ModelAsset::Bytes {
239 name: path.clone(),
240 data: data.clone(),
241 })
242 })
243 .collect::<Vec<_>>();
244 shards.sort_by_key(ModelAsset::display_name);
245 shards
246 }
247
248 fn discover_pth_weights(&self, prefix: &str) -> Vec<ModelAsset> {
249 let prefix = normalize_asset_path(prefix);
250 let prefix = if prefix.is_empty() {
251 String::new()
252 } else if prefix.ends_with('/') {
253 prefix
254 } else {
255 format!("{prefix}/")
256 };
257
258 let mut weights = self
259 .entries
260 .iter()
261 .filter_map(|(path, data)| {
262 let rest = if prefix.is_empty() {
263 path.as_str()
264 } else {
265 path.strip_prefix(&prefix)?
266 };
267 if rest.contains('/') || !rest.ends_with(".pth") {
268 return None;
269 }
270 Some(ModelAsset::Bytes {
271 name: path.clone(),
272 data: data.clone(),
273 })
274 })
275 .collect::<Vec<_>>();
276 weights.sort_by_key(ModelAsset::display_name);
277 weights
278 }
279}
280
/// The set of files a model needs, each either supplied explicitly or filled
/// in later by one of the `fill_from_*` methods.
#[derive(Debug, Clone, Default)]
pub struct ModelFiles {
    /// Model configuration; auto-discovery looks for `config.json`, then `params.json`.
    pub config: Option<ModelAsset>,

    /// Text tokenizer; auto-discovery looks for `tokenizer.json`, then `tekken.json`.
    pub tokenizer: Option<ModelAsset>,

    /// Main model weights: a single file or several shards.
    pub weights: Vec<ModelAsset>,

    /// Directory of voice files; auto-discovery looks for `voices`, then `voice_embedding`.
    pub voices_dir: Option<ModelAssetDir>,

    /// Weights of the nested audio/speech tokenizer
    /// (`audio_tokenizer/` or `speech_tokenizer/` during auto-discovery).
    pub speech_tokenizer_weights: Vec<ModelAsset>,

    /// Configuration of the nested audio/speech tokenizer.
    pub speech_tokenizer_config: Option<ModelAsset>,

    /// Optional `generation_config.json`.
    pub generation_config: Option<ModelAsset>,

    /// Optional `preprocessor_config.json`.
    pub preprocessor_config: Option<ModelAsset>,
}
396
397impl ModelFiles {
398 pub fn fill_from_directory(&mut self, dir: &Path) {
401 if self.config.is_none() {
403 let p = dir.join("config.json");
404 if p.exists() {
405 info!("Auto-discovered config: {}", p.display());
406 self.config = Some(ModelAsset::from_path(p));
407 } else {
408 let p = dir.join("params.json");
409 if p.exists() {
410 info!("Auto-discovered config: {}", p.display());
411 self.config = Some(ModelAsset::from_path(p));
412 }
413 }
414 }
415
416 if self.tokenizer.is_none() {
418 let p = dir.join("tokenizer.json");
419 if p.exists() {
420 info!("Auto-discovered tokenizer: {}", p.display());
421 self.tokenizer = Some(ModelAsset::from_path(p));
422 } else {
423 let p = dir.join("tekken.json");
424 if p.exists() {
425 info!("Auto-discovered tokenizer: {}", p.display());
426 self.tokenizer = Some(ModelAsset::from_path(p));
427 }
428 }
429 }
430
431 if self.weights.is_empty() {
433 let single = dir.join("model.safetensors");
434 if single.exists() {
435 info!("Auto-discovered single weight file");
436 self.weights.push(ModelAsset::from_path(single));
437 } else {
438 let single = dir.join("consolidated.safetensors");
439 if single.exists() {
440 info!("Auto-discovered single weight file");
441 self.weights.push(ModelAsset::from_path(single));
442 } else {
443 self.discover_sharded_weights(dir);
444 }
445 }
446 if self.weights.is_empty() {
448 self.discover_pth_weights(dir);
449 }
450 }
451
452 if self.voices_dir.is_none() {
454 let p = dir.join("voices");
455 if p.is_dir() {
456 info!("Auto-discovered voices dir: {}", p.display());
457 self.voices_dir = Some(ModelAssetDir::from_path(p));
458 } else {
459 let p = dir.join("voice_embedding");
460 if p.is_dir() {
461 info!("Auto-discovered voices dir: {}", p.display());
462 self.voices_dir = Some(ModelAssetDir::from_path(p));
463 }
464 }
465 }
466
467 if self.generation_config.is_none() {
469 let p = dir.join("generation_config.json");
470 if p.exists() {
471 info!("Auto-discovered generation config: {}", p.display());
472 self.generation_config = Some(ModelAsset::from_path(p));
473 }
474 }
475
476 if self.preprocessor_config.is_none() {
477 let p = dir.join("preprocessor_config.json");
478 if p.exists() {
479 info!("Auto-discovered preprocessor config: {}", p.display());
480 self.preprocessor_config = Some(ModelAsset::from_path(p));
481 }
482 }
483
484 for nested_dir_name in ["audio_tokenizer", "speech_tokenizer"] {
485 let nested_dir = dir.join(nested_dir_name);
486 if !nested_dir.is_dir() {
487 continue;
488 }
489
490 if self.speech_tokenizer_config.is_none() {
491 let p = nested_dir.join("config.json");
492 if p.exists() {
493 info!(
494 "Auto-discovered {} config: {}",
495 nested_dir_name,
496 p.display()
497 );
498 self.speech_tokenizer_config = Some(ModelAsset::from_path(p));
499 }
500 }
501
502 if self.speech_tokenizer_weights.is_empty() {
503 let single = nested_dir.join("model.safetensors");
504 if single.exists() {
505 info!("Auto-discovered {} weight file", nested_dir_name);
506 self.speech_tokenizer_weights
507 .push(ModelAsset::from_path(single));
508 } else {
509 let mut shards = Self::discover_sharded_weights_in_dir(&nested_dir);
510 if !shards.is_empty() {
511 info!(
512 "Auto-discovered {} {} weight shards",
513 shards.len(),
514 nested_dir_name
515 );
516 self.speech_tokenizer_weights.append(&mut shards);
517 }
518 }
519 }
520 }
521 }
522
523 pub fn fill_from_asset_bundle(&mut self, bundle: &ModelAssetBundle) {
525 if self.config.is_none() {
526 self.config = bundle
527 .get("config.json")
528 .or_else(|| bundle.get("params.json"));
529 }
530
531 if self.tokenizer.is_none() {
532 self.tokenizer = bundle
533 .get("tokenizer.json")
534 .or_else(|| bundle.get("tekken.json"));
535 }
536
537 if self.weights.is_empty() {
538 if let Some(asset) = bundle.get("model.safetensors") {
539 self.weights.push(asset);
540 } else if let Some(asset) = bundle.get("consolidated.safetensors") {
541 self.weights.push(asset);
542 } else {
543 self.weights = bundle.discover_sharded_weights("");
544 }
545 if self.weights.is_empty() {
546 self.weights = bundle.discover_pth_weights("");
547 }
548 }
549
550 if self.voices_dir.is_none() {
551 self.voices_dir = bundle
552 .collect_directory("voices")
553 .or_else(|| bundle.collect_directory("voice_embedding"));
554 }
555
556 if self.generation_config.is_none() {
557 self.generation_config = bundle.get("generation_config.json");
558 }
559
560 if self.preprocessor_config.is_none() {
561 self.preprocessor_config = bundle.get("preprocessor_config.json");
562 }
563
564 for nested_dir_name in ["audio_tokenizer", "speech_tokenizer"] {
565 if self.speech_tokenizer_config.is_none() {
566 self.speech_tokenizer_config =
567 bundle.get(format!("{nested_dir_name}/config.json").as_str());
568 }
569
570 if self.speech_tokenizer_weights.is_empty() {
571 if let Some(asset) =
572 bundle.get(format!("{nested_dir_name}/model.safetensors").as_str())
573 {
574 self.speech_tokenizer_weights.push(asset);
575 } else {
576 self.speech_tokenizer_weights =
577 bundle.discover_sharded_weights(nested_dir_name);
578 }
579 }
580 }
581 }
582
583 fn discover_pth_weights(&mut self, dir: &Path) {
585 let Ok(entries) = std::fs::read_dir(dir) else {
586 return;
587 };
588
589 let mut pth_files: Vec<ModelAsset> = entries
590 .filter_map(|e| e.ok())
591 .map(|e| e.path())
592 .filter(|p| {
593 p.extension()
594 .and_then(|ext| ext.to_str())
595 .is_some_and(|ext| ext == "pth")
596 })
597 .map(ModelAsset::from_path)
598 .collect();
599
600 if !pth_files.is_empty() {
601 pth_files.sort_by_key(ModelAsset::display_name);
602 info!("Auto-discovered {} .pth weight file(s)", pth_files.len());
603 self.weights = pth_files;
604 }
605 }
606
607 fn discover_sharded_weights(&mut self, dir: &Path) {
609 let shards = Self::discover_sharded_weights_in_dir(dir);
610
611 if !shards.is_empty() {
612 info!("Auto-discovered {} weight shards", shards.len());
613 self.weights = shards;
614 }
615 }
616
617 fn discover_sharded_weights_in_dir(dir: &Path) -> Vec<ModelAsset> {
618 let Ok(entries) = std::fs::read_dir(dir) else {
619 return Vec::new();
620 };
621
622 let mut shards: Vec<ModelAsset> = entries
623 .filter_map(|e| e.ok())
624 .map(|e| e.path())
625 .filter(|p| {
626 p.file_name()
627 .and_then(|n| n.to_str())
628 .is_some_and(|n| n.starts_with("model-") && n.ends_with(".safetensors"))
629 })
630 .map(ModelAsset::from_path)
631 .collect();
632 shards.sort_by_key(ModelAsset::display_name);
633 shards
634 }
635
636 pub fn load_safetensors_vb(
643 assets: &[ModelAsset],
644 dtype: candle_core::DType,
645 device: &candle_core::Device,
646 ) -> Result<VarBuilder<'static>, TtsError> {
647 if assets.is_empty() {
648 return Err(TtsError::FileMissing("safetensors weight files".into()));
649 }
650
651 if assets.len() == 1 {
652 if let Some(path) = assets[0].as_path() {
653 let data = std::fs::read(path).map_err(|e| {
654 TtsError::WeightLoadError(format!("Failed to read {}: {}", path.display(), e))
655 })?;
656 return VarBuilder::from_buffered_safetensors(data, dtype, device)
657 .map_err(|e| TtsError::WeightLoadError(e.to_string()));
658 }
659 }
660
661 let mut all_tensors: HashMap<String, candle_core::Tensor> = HashMap::new();
663 for asset in assets {
664 let data = asset.read_bytes().map_err(|e| {
665 TtsError::WeightLoadError(format!("Failed to read {}: {}", asset.display_name(), e))
666 })?;
667 let tensors = safetensors::SafeTensors::deserialize(&data).map_err(|e| {
668 TtsError::WeightLoadError(format!(
669 "Failed to parse {}: {}",
670 asset.display_name(),
671 e
672 ))
673 })?;
674 for (name, view) in tensors.tensors() {
675 let native_dtype = match view.dtype() {
677 safetensors::Dtype::F16 => candle_core::DType::F16,
678 safetensors::Dtype::BF16 => candle_core::DType::BF16,
679 safetensors::Dtype::F32 => candle_core::DType::F32,
680 safetensors::Dtype::F64 => candle_core::DType::F64,
681 safetensors::Dtype::I64 => candle_core::DType::I64,
682 safetensors::Dtype::I32 => candle_core::DType::I64, safetensors::Dtype::U32 => candle_core::DType::U32,
684 safetensors::Dtype::U8 => candle_core::DType::U8,
685 _ => candle_core::DType::F32, };
687
688 let tensor = candle_core::Tensor::from_raw_buffer(
690 view.data(),
691 native_dtype,
692 view.shape(),
693 device,
694 )
695 .map_err(|e| {
696 TtsError::WeightLoadError(format!("Failed to load tensor '{}': {}", name, e))
697 })?;
698
699 let tensor = if native_dtype != dtype {
701 tensor.to_dtype(dtype).map_err(|e| {
702 TtsError::WeightLoadError(format!(
703 "Failed to convert tensor '{}' to {:?}: {}",
704 name, dtype, e
705 ))
706 })?
707 } else {
708 tensor
709 };
710
711 all_tensors.insert(name, tensor);
712 }
713 }
714
715 Ok(VarBuilder::from_tensors(all_tensors, dtype, device))
716 }
717
718 #[cfg(feature = "download")]
722 pub fn fill_from_hf(
723 &mut self,
724 model_id: &str,
725 model_type: ModelType,
726 bearer_token: Option<&str>,
727 ) -> Result<(), TtsError> {
728 use crate::download::download_file_with_token;
729
730 let download = |repo: &str, file: &str| download_file_with_token(repo, file, bearer_token);
731
732 if self.config.is_none() {
734 let config_name = if model_type == ModelType::Voxtral {
735 "params.json"
736 } else {
737 "config.json"
738 };
739 info!("Downloading {} from {}", config_name, model_id);
740 self.config = Some(ModelAsset::from_path(download(model_id, config_name)?));
741 }
742
743 if model_type != ModelType::Kokoro && self.tokenizer.is_none() {
745 let tokenizer_name = if model_type == ModelType::Voxtral {
746 "tekken.json"
747 } else {
748 "tokenizer.json"
749 };
750 info!("Downloading {} from {}", tokenizer_name, model_id);
751 match download(model_id, tokenizer_name) {
752 Ok(p) => self.tokenizer = Some(ModelAsset::from_path(p)),
753 Err(_) => {
754 if model_type == ModelType::Voxtral {
755 return Err(TtsError::FileMissing(
756 "tekken.json — Voxtral Tekken tokenizer".to_string(),
757 ));
758 }
759 let fallback_repo = match model_type {
760 ModelType::Qwen3Tts => "Qwen/Qwen2.5-0.5B",
761 ModelType::VibeVoice => "Qwen/Qwen2.5-1.5B",
762 ModelType::VibeVoiceRealtime => "Qwen/Qwen2.5-0.5B",
763 _ => "Qwen/Qwen2.5-0.5B",
764 };
765 info!(
766 "tokenizer.json not in {}; falling back to {}",
767 model_id, fallback_repo
768 );
769 self.tokenizer = Some(ModelAsset::from_path(download(
770 fallback_repo,
771 "tokenizer.json",
772 )?));
773 }
774 }
775 }
776
777 if self.generation_config.is_none() {
779 if let Ok(p) = download(model_id, "generation_config.json") {
780 self.generation_config = Some(ModelAsset::from_path(p));
781 }
782 }
783
784 if self.preprocessor_config.is_none() {
785 if let Ok(p) = download(model_id, "preprocessor_config.json") {
786 self.preprocessor_config = Some(ModelAsset::from_path(p));
787 }
788 }
789
790 if self.weights.is_empty() {
792 self.download_weights_from_hf(model_id, bearer_token)?;
793 }
794
795 match model_type {
797 ModelType::Kokoro => {
798 self.download_kokoro_extras(model_id, bearer_token)?;
799 }
800 ModelType::OmniVoice => {
801 self.download_omnivoice_extras(model_id, bearer_token)?;
802 }
803 ModelType::Voxtral => {
804 self.download_voxtral_extras(model_id, bearer_token)?;
805 }
806 ModelType::Qwen3Tts => {
807 self.download_qwen3tts_extras(bearer_token)?;
808 }
809 ModelType::VibeVoice | ModelType::VibeVoiceRealtime => {
810 self.download_vibevoice_extras(model_id, bearer_token)?;
811 }
812 }
813
814 Ok(())
815 }
816
817 #[cfg(feature = "download")]
819 fn download_weights_from_hf(
820 &mut self,
821 model_id: &str,
822 bearer_token: Option<&str>,
823 ) -> Result<(), TtsError> {
824 use crate::download::download_file_with_token;
825
826 let download = |repo: &str, file: &str| download_file_with_token(repo, file, bearer_token);
827
828 if let Ok(p) = download(model_id, "model.safetensors") {
830 self.weights.push(ModelAsset::from_path(p));
831 return Ok(());
832 }
833
834 if let Ok(p) = download(model_id, "consolidated.safetensors") {
836 self.weights.push(ModelAsset::from_path(p));
837 return Ok(());
838 }
839
840 for pth_name in &["kokoro-v1_0.pth", "kokoro-v1_1-zh.pth", "model.pth"] {
842 if let Ok(p) = download(model_id, pth_name) {
843 self.weights.push(ModelAsset::from_path(p));
844 return Ok(());
845 }
846 }
847
848 let index_path = download(model_id, "model.safetensors.index.json")?;
850 let index_content = std::fs::read_to_string(&index_path)?;
851 let index: serde_json::Value = serde_json::from_str(&index_content)?;
852
853 if let Some(weight_map) = index.get("weight_map").and_then(|v| v.as_object()) {
854 let mut shard_names: Vec<String> = weight_map
855 .values()
856 .filter_map(|v| v.as_str().map(String::from))
857 .collect();
858 shard_names.sort();
859 shard_names.dedup();
860
861 for shard_name in &shard_names {
862 info!("Downloading shard: {}", shard_name);
863 let p = download(model_id, shard_name)?;
864 self.weights.push(ModelAsset::from_path(p));
865 }
866 }
867
868 Ok(())
869 }
870
871 #[cfg(feature = "download")]
873 fn download_kokoro_extras(
874 &mut self,
875 model_id: &str,
876 bearer_token: Option<&str>,
877 ) -> Result<(), TtsError> {
878 use crate::download::download_file_with_token;
879
880 let download = |repo: &str, file: &str| download_file_with_token(repo, file, bearer_token);
881
882 if self.voices_dir.is_none() {
883 if let Ok(voice_path) = download(model_id, "voices/af_heart.pt") {
885 if let Some(parent) = voice_path.parent() {
886 self.voices_dir = Some(ModelAssetDir::from_path(parent.to_path_buf()));
887 }
888 }
889 }
890
891 Ok(())
892 }
893
894 #[cfg(feature = "download")]
896 fn download_qwen3tts_extras(&mut self, bearer_token: Option<&str>) -> Result<(), TtsError> {
897 use crate::download::download_file_with_token;
898
899 let tokenizer_repo = "Qwen/Qwen3-TTS-Tokenizer-12Hz";
900 let download = |repo: &str, file: &str| download_file_with_token(repo, file, bearer_token);
901
902 if self.speech_tokenizer_config.is_none() {
903 info!(
904 "Downloading speech tokenizer config from {}",
905 tokenizer_repo
906 );
907 if let Ok(p) = download(tokenizer_repo, "config.json") {
908 self.speech_tokenizer_config = Some(ModelAsset::from_path(p));
909 }
910 }
911
912 if self.speech_tokenizer_weights.is_empty() {
913 info!(
914 "Downloading speech tokenizer weights from {}",
915 tokenizer_repo
916 );
917 if let Ok(p) = download(tokenizer_repo, "model.safetensors") {
918 self.speech_tokenizer_weights.push(ModelAsset::from_path(p));
919 } else if let Ok(index_path) = download(tokenizer_repo, "model.safetensors.index.json")
920 {
921 if let Ok(content) = std::fs::read_to_string(&index_path) {
922 if let Ok(index) = serde_json::from_str::<serde_json::Value>(&content) {
923 if let Some(weight_map) =
924 index.get("weight_map").and_then(|v| v.as_object())
925 {
926 let mut shard_names: Vec<String> = weight_map
927 .values()
928 .filter_map(|v| v.as_str().map(String::from))
929 .collect();
930 shard_names.sort();
931 shard_names.dedup();
932
933 for shard_name in &shard_names {
934 if let Ok(p) = download(tokenizer_repo, shard_name) {
935 self.speech_tokenizer_weights.push(ModelAsset::from_path(p));
936 }
937 }
938 }
939 }
940 }
941 }
942 }
943
944 Ok(())
945 }
946
947 #[cfg(feature = "download")]
948 fn download_vibevoice_extras(
949 &mut self,
950 model_id: &str,
951 bearer_token: Option<&str>,
952 ) -> Result<(), TtsError> {
953 use crate::download::download_file_with_token;
954
955 let download = |repo: &str, file: &str| download_file_with_token(repo, file, bearer_token);
956
957 if self.preprocessor_config.is_none() {
958 if let Ok(p) = download(model_id, "preprocessor_config.json") {
959 self.preprocessor_config = Some(ModelAsset::from_path(p));
960 }
961 }
962
963 Ok(())
964 }
965
966 #[cfg(feature = "download")]
968 fn download_omnivoice_extras(
969 &mut self,
970 model_id: &str,
971 bearer_token: Option<&str>,
972 ) -> Result<(), TtsError> {
973 use crate::download::download_file_with_token;
974
975 let download = |repo: &str, file: &str| download_file_with_token(repo, file, bearer_token);
976
977 if self.speech_tokenizer_config.is_none() {
978 if let Ok(p) = download(model_id, "audio_tokenizer/config.json") {
979 self.speech_tokenizer_config = Some(ModelAsset::from_path(p));
980 }
981 }
982
983 if self.speech_tokenizer_weights.is_empty() {
984 if let Ok(p) = download(model_id, "audio_tokenizer/model.safetensors") {
985 self.speech_tokenizer_weights.push(ModelAsset::from_path(p));
986 } else if let Ok(index_path) =
987 download(model_id, "audio_tokenizer/model.safetensors.index.json")
988 {
989 if let Ok(content) = std::fs::read_to_string(&index_path) {
990 if let Ok(index) = serde_json::from_str::<serde_json::Value>(&content) {
991 if let Some(weight_map) =
992 index.get("weight_map").and_then(|v| v.as_object())
993 {
994 let mut shard_names: Vec<String> = weight_map
995 .values()
996 .filter_map(|v| v.as_str().map(String::from))
997 .collect();
998 shard_names.sort();
999 shard_names.dedup();
1000
1001 for shard_name in &shard_names {
1002 let shard_path = format!("audio_tokenizer/{}", shard_name);
1003 if let Ok(p) = download(model_id, &shard_path) {
1004 self.speech_tokenizer_weights.push(ModelAsset::from_path(p));
1005 }
1006 }
1007 }
1008 }
1009 }
1010 }
1011 }
1012
1013 Ok(())
1014 }
1015
1016 #[cfg(feature = "download")]
1018 fn download_voxtral_extras(
1019 &mut self,
1020 model_id: &str,
1021 bearer_token: Option<&str>,
1022 ) -> Result<(), TtsError> {
1023 use crate::download::download_file_with_token;
1024
1025 let download = |repo: &str, file: &str| download_file_with_token(repo, file, bearer_token);
1026
1027 if self.voices_dir.is_some() {
1028 return Ok(());
1029 }
1030
1031 let config_path = self.config.as_ref().ok_or_else(|| {
1032 TtsError::FileMissing("params.json — Voxtral model configuration".to_string())
1033 })?;
1034 let content = config_path.read_bytes()?;
1035 let config: serde_json::Value = serde_json::from_slice(&content)?;
1036 let voices = config
1037 .get("multimodal")
1038 .and_then(|v| v.get("audio_tokenizer_args"))
1039 .and_then(|v| v.get("voice"))
1040 .and_then(|v| v.as_object())
1041 .ok_or_else(|| {
1042 TtsError::ConfigError(
1043 "params.json is missing multimodal.audio_tokenizer_args.voice".to_string(),
1044 )
1045 })?;
1046
1047 let mut discovered_dir: Option<ModelAssetDir> = None;
1048 for voice_name in voices.keys() {
1049 let path = download(model_id, &format!("voice_embedding/{voice_name}.pt"))?;
1050 if discovered_dir.is_none() {
1051 discovered_dir = path
1052 .parent()
1053 .map(|parent| ModelAssetDir::from_path(parent.to_path_buf()));
1054 }
1055 }
1056
1057 self.voices_dir = discovered_dir;
1058 Ok(())
1059 }
1060
1061 pub fn validate(&self, model_type: ModelType) -> Result<(), TtsError> {
1063 if model_type == ModelType::Voxtral {
1064 if self.config.is_none() {
1065 return Err(TtsError::FileMissing(
1066 "params.json — Voxtral model configuration".to_string(),
1067 ));
1068 }
1069 if self.tokenizer.is_none() {
1070 return Err(TtsError::FileMissing(
1071 "tekken.json — Voxtral Tekken tokenizer".to_string(),
1072 ));
1073 }
1074 if self.weights.is_empty() {
1075 return Err(TtsError::FileMissing(
1076 "consolidated.safetensors — Voxtral model weights".to_string(),
1077 ));
1078 }
1079 if self.voices_dir.is_none() {
1080 return Err(TtsError::FileMissing(
1081 "voice_embedding/ — Voxtral preset voice embeddings".to_string(),
1082 ));
1083 }
1084 return Ok(());
1085 }
1086
1087 if self.config.is_none() {
1088 return Err(TtsError::FileMissing(
1089 "config.json — model architecture configuration".to_string(),
1090 ));
1091 }
1092 if model_type != ModelType::Kokoro && self.tokenizer.is_none() {
1094 return Err(TtsError::FileMissing(
1095 "tokenizer.json — BPE text tokenizer".to_string(),
1096 ));
1097 }
1098 if self.weights.is_empty() {
1099 return Err(TtsError::FileMissing(
1100 "model weight files (.safetensors or .pth)".to_string(),
1101 ));
1102 }
1103
1104 match model_type {
1105 ModelType::OmniVoice => {
1106 if self.speech_tokenizer_config.is_none() {
1107 return Err(TtsError::FileMissing(
1108 "audio tokenizer config (audio_tokenizer/config.json) \
1109 — configures OmniVoice's codec decoder"
1110 .to_string(),
1111 ));
1112 }
1113 if self.speech_tokenizer_weights.is_empty() {
1114 return Err(TtsError::FileMissing(
1115 "audio tokenizer weights (audio_tokenizer/model.safetensors) \
1116 — converts OmniVoice codec tokens to audio waveform"
1117 .to_string(),
1118 ));
1119 }
1120 }
1121 ModelType::Qwen3Tts => {
1122 if self.speech_tokenizer_weights.is_empty() {
1123 return Err(TtsError::FileMissing(
1124 "speech tokenizer weights (Qwen3-TTS-Tokenizer-12Hz model.safetensors) \
1125 — converts codec tokens to audio waveform"
1126 .to_string(),
1127 ));
1128 }
1129 }
1130 ModelType::Kokoro => {
1131 }
1133 ModelType::VibeVoice | ModelType::VibeVoiceRealtime => {}
1134 ModelType::Voxtral => unreachable!(),
1135 }
1136
1137 Ok(())
1138 }
1139
1140 pub fn missing_files(&self, model_type: ModelType) -> Vec<&'static str> {
1142 if model_type == ModelType::Voxtral {
1143 let mut missing = Vec::new();
1144 if self.config.is_none() {
1145 missing.push("params.json");
1146 }
1147 if self.tokenizer.is_none() {
1148 missing.push("tekken.json");
1149 }
1150 if self.weights.is_empty() {
1151 missing.push("consolidated.safetensors");
1152 }
1153 if self.voices_dir.is_none() {
1154 missing.push("voice_embedding");
1155 }
1156 return missing;
1157 }
1158
1159 let mut missing = Vec::new();
1160
1161 if self.config.is_none() {
1162 missing.push("config.json");
1163 }
1164 if model_type != ModelType::Kokoro && self.tokenizer.is_none() {
1165 missing.push("tokenizer.json");
1166 }
1167 if self.weights.is_empty() {
1168 missing.push("model weight files");
1169 }
1170 if model_type == ModelType::OmniVoice && self.speech_tokenizer_config.is_none() {
1171 missing.push("audio tokenizer config");
1172 }
1173 if model_type == ModelType::OmniVoice && self.speech_tokenizer_weights.is_empty() {
1174 missing.push("audio tokenizer weights");
1175 }
1176 if model_type == ModelType::Qwen3Tts && self.speech_tokenizer_weights.is_empty() {
1177 missing.push("speech tokenizer weights");
1178 }
1179
1180 missing
1181 }
1182}
1183
/// Floating-point precision selectable for a model; the default is `BF16`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum DType {
    /// Maps to `candle_core::DType::F32`.
    F32,
    /// Maps to `candle_core::DType::F16`.
    F16,
    /// Maps to `candle_core::DType::BF16`.
    #[default]
    BF16,
}
1199
1200impl DType {
1201 pub fn to_candle(self) -> candle_core::DType {
1203 match self {
1204 Self::F32 => candle_core::DType::F32,
1205 Self::F16 => candle_core::DType::F16,
1206 Self::BF16 => candle_core::DType::BF16,
1207 }
1208 }
1209
1210 pub fn label(self) -> &'static str {
1212 match self {
1213 Self::F32 => "f32",
1214 Self::F16 => "f16",
1215 Self::BF16 => "bf16",
1216 }
1217 }
1218}
1219
/// A device paired with the dtype to use on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RuntimeChoice {
    /// Device to run on.
    pub device: DeviceSelection,
    /// Dtype to use on that device.
    pub dtype: DType,
}
1231
1232impl RuntimeChoice {
1233 pub fn label(&self) -> String {
1235 format!("{} ({})", self.device.label(), self.dtype.label())
1236 }
1237}
1238
1239pub fn preferred_runtime_choices(model_type: ModelType) -> Vec<RuntimeChoice> {
1241 DeviceSelection::available_runtime_candidates()
1242 .into_iter()
1243 .map(|device| RuntimeChoice {
1244 device,
1245 dtype: preferred_dtype_for(model_type, device),
1246 })
1247 .collect()
1248}
1249
1250pub fn preferred_runtime_choice(model_type: ModelType) -> RuntimeChoice {
1252 preferred_runtime_choices(model_type)
1253 .into_iter()
1254 .next()
1255 .unwrap_or(RuntimeChoice {
1256 device: DeviceSelection::Cpu,
1257 dtype: DType::F32,
1258 })
1259}
1260
1261fn preferred_dtype_for(model_type: ModelType, device: DeviceSelection) -> DType {
1262 match model_type {
1263 ModelType::OmniVoice => match device {
1264 DeviceSelection::Cpu => DType::F32,
1265 DeviceSelection::Cuda(_) => DType::BF16,
1266 DeviceSelection::Metal(_) => DType::F32,
1267 DeviceSelection::Auto => DType::BF16,
1268 },
1269 ModelType::Kokoro => match device {
1270 DeviceSelection::Cpu => DType::F32,
1271 DeviceSelection::Cuda(_) => DType::BF16,
1272 DeviceSelection::Metal(_) => DType::F32,
1273 DeviceSelection::Auto => DType::BF16,
1274 },
1275 ModelType::Qwen3Tts => match device {
1276 DeviceSelection::Cpu => DType::F32,
1277 DeviceSelection::Cuda(_) => DType::BF16,
1278 DeviceSelection::Metal(_) => DType::BF16,
1279 DeviceSelection::Auto => DType::BF16,
1280 },
1281 ModelType::VibeVoice | ModelType::VibeVoiceRealtime => match device {
1282 DeviceSelection::Cpu => DType::F32,
1283 DeviceSelection::Cuda(_) => DType::BF16,
1284 DeviceSelection::Metal(_) => DType::F32,
1285 DeviceSelection::Auto => DType::BF16,
1286 },
1287 ModelType::Voxtral => match device {
1288 DeviceSelection::Cpu => DType::F32,
1289 DeviceSelection::Cuda(_) => DType::BF16,
1290 DeviceSelection::Metal(_) => DType::F32,
1291 DeviceSelection::Auto => DType::BF16,
1292 },
1293 }
1294}
1295
/// Builder-style configuration describing which model to use, where its
/// files come from, and the device/dtype to use.
#[derive(Debug, Clone)]
pub struct TtsConfig {
    /// Which model architecture this configuration is for.
    pub model_type: ModelType,

    /// Optional model path string (set by `with_model_path`).
    pub model_path: Option<String>,

    /// Optional Hugging Face model id (set by `with_hf_model_id`).
    pub hf_model_id: Option<String>,

    /// Optional runtime command (set by `with_runtime_command`).
    // NOTE(review): how this is consumed is not visible in this part of the file.
    pub runtime_command: Option<String>,

    /// Optional runtime endpoint (set by `with_runtime_endpoint`).
    // NOTE(review): how this is consumed is not visible in this part of the file.
    pub runtime_endpoint: Option<String>,

    /// Optional bearer token (set by `with_bearer_token`).
    pub bearer_token: Option<String>,

    /// Device selection; `TtsConfig::new` starts with `DeviceSelection::Auto`.
    pub device: DeviceSelection,

    /// Dtype; `TtsConfig::new` starts with `DType::default()`.
    pub dtype: DType,

    /// Explicitly supplied or discovered model files.
    pub files: ModelFiles,

    /// In-memory assets supplied by the caller.
    pub asset_bundle: ModelAssetBundle,
}
1379
1380impl TtsConfig {
1381 pub fn new(model_type: ModelType) -> Self {
1383 Self {
1384 model_type,
1385 model_path: None,
1386 hf_model_id: None,
1387 runtime_command: None,
1388 runtime_endpoint: None,
1389 bearer_token: None,
1390 device: DeviceSelection::Auto,
1391 dtype: DType::default(),
1392 files: ModelFiles::default(),
1393 asset_bundle: ModelAssetBundle::default(),
1394 }
1395 }
1396
1397 pub fn with_model_path(mut self, path: impl Into<String>) -> Self {
1404 self.model_path = Some(path.into());
1405 self
1406 }
1407
1408 pub fn with_asset_bundle(mut self, bundle: ModelAssetBundle) -> Self {
1410 self.asset_bundle = bundle;
1411 self
1412 }
1413
1414 pub fn with_asset_bytes(
1417 mut self,
1418 relative_path: impl Into<String>,
1419 bytes: impl Into<Vec<u8>>,
1420 ) -> Self {
1421 self.asset_bundle.insert_bytes(relative_path, bytes);
1422 self
1423 }
1424
1425 pub fn with_hf_model_id(mut self, id: impl Into<String>) -> Self {
1429 self.hf_model_id = Some(id.into());
1430 self
1431 }
1432
1433 pub fn with_runtime_command(mut self, command: impl Into<String>) -> Self {
1435 self.runtime_command = Some(command.into());
1436 self
1437 }
1438
1439 pub fn with_runtime_endpoint(mut self, endpoint: impl Into<String>) -> Self {
1441 self.runtime_endpoint = Some(endpoint.into());
1442 self
1443 }
1444
1445 pub fn with_bearer_token(mut self, token: impl Into<String>) -> Self {
1447 self.bearer_token = Some(token.into());
1448 self
1449 }
1450
1451 pub fn with_device(mut self, device: DeviceSelection) -> Self {
1455 self.device = device;
1456 self
1457 }
1458
1459 pub fn with_dtype(mut self, dtype: DType) -> Self {
1461 self.dtype = dtype;
1462 self
1463 }
1464
1465 pub fn with_preferred_runtime(mut self) -> Self {
1470 let runtime = preferred_runtime_choice(self.model_type);
1471 self.device = runtime.device;
1472 self.dtype = runtime.dtype;
1473 self
1474 }
1475
1476 pub fn with_config_file(mut self, path: impl Into<PathBuf>) -> Self {
1484 self.files.config = Some(ModelAsset::from_path(path.into()));
1485 self
1486 }
1487
1488 pub fn with_config_bytes(mut self, bytes: impl Into<Vec<u8>>) -> Self {
1490 self.files.config = Some(ModelAsset::from_bytes("config.json", bytes));
1491 self
1492 }
1493
1494 pub fn with_tokenizer_file(mut self, path: impl Into<PathBuf>) -> Self {
1500 self.files.tokenizer = Some(ModelAsset::from_path(path.into()));
1501 self
1502 }
1503
1504 pub fn with_tokenizer_bytes(mut self, bytes: impl Into<Vec<u8>>) -> Self {
1506 self.files.tokenizer = Some(ModelAsset::from_bytes("tokenizer.json", bytes));
1507 self
1508 }
1509
1510 pub fn with_weight_file(mut self, path: impl Into<PathBuf>) -> Self {
1521 self.files.weights.push(ModelAsset::from_path(path.into()));
1522 self
1523 }
1524
1525 pub fn with_weight_bytes(
1527 mut self,
1528 file_name: impl Into<String>,
1529 bytes: impl Into<Vec<u8>>,
1530 ) -> Self {
1531 self.files
1532 .weights
1533 .push(ModelAsset::from_bytes(file_name.into(), bytes));
1534 self
1535 }
1536
1537 pub fn with_weight_files(mut self, paths: Vec<PathBuf>) -> Self {
1539 self.files.weights = paths.into_iter().map(ModelAsset::from_path).collect();
1540 self
1541 }
1542
1543 pub fn with_voices_dir(mut self, path: impl Into<PathBuf>) -> Self {
1548 self.files.voices_dir = Some(ModelAssetDir::from_path(path.into()));
1549 self
1550 }
1551
1552 pub fn with_voice_bytes(
1554 mut self,
1555 voice_name: impl Into<String>,
1556 bytes: impl Into<Vec<u8>>,
1557 ) -> Self {
1558 let voice_file = format!("{}.pt", voice_name.into());
1559 match self.files.voices_dir.take() {
1560 Some(ModelAssetDir::Bytes(mut entries)) => {
1561 entries.insert(voice_file, Arc::from(bytes.into()));
1562 self.files.voices_dir = Some(ModelAssetDir::from_bytes(entries));
1563 }
1564 Some(ModelAssetDir::Path(path)) => {
1565 self.files.voices_dir = Some(ModelAssetDir::Path(path));
1566 self.asset_bundle
1567 .insert_bytes(format!("voices/{voice_file}"), bytes);
1568 }
1569 None => {
1570 let mut entries = BTreeMap::new();
1571 entries.insert(voice_file, Arc::from(bytes.into()));
1572 self.files.voices_dir = Some(ModelAssetDir::from_bytes(entries));
1573 }
1574 }
1575 self
1576 }
1577
1578 pub fn with_speech_tokenizer_weight_file(mut self, path: impl Into<PathBuf>) -> Self {
1584 self.files
1585 .speech_tokenizer_weights
1586 .push(ModelAsset::from_path(path.into()));
1587 self
1588 }
1589
1590 pub fn with_speech_tokenizer_weight_bytes(
1592 mut self,
1593 file_name: impl Into<String>,
1594 bytes: impl Into<Vec<u8>>,
1595 ) -> Self {
1596 self.files
1597 .speech_tokenizer_weights
1598 .push(ModelAsset::from_bytes(file_name.into(), bytes));
1599 self
1600 }
1601
1602 pub fn with_speech_tokenizer_weight_files(mut self, paths: Vec<PathBuf>) -> Self {
1604 self.files.speech_tokenizer_weights =
1605 paths.into_iter().map(ModelAsset::from_path).collect();
1606 self
1607 }
1608
1609 pub fn with_speech_tokenizer_config_file(mut self, path: impl Into<PathBuf>) -> Self {
1614 self.files.speech_tokenizer_config = Some(ModelAsset::from_path(path.into()));
1615 self
1616 }
1617
1618 pub fn with_speech_tokenizer_config_bytes(mut self, bytes: impl Into<Vec<u8>>) -> Self {
1620 self.files.speech_tokenizer_config = Some(ModelAsset::from_bytes(
1621 "speech_tokenizer/config.json",
1622 bytes,
1623 ));
1624 self
1625 }
1626
1627 pub fn with_generation_config_file(mut self, path: impl Into<PathBuf>) -> Self {
1632 self.files.generation_config = Some(ModelAsset::from_path(path.into()));
1633 self
1634 }
1635
1636 pub fn with_generation_config_bytes(mut self, bytes: impl Into<Vec<u8>>) -> Self {
1638 self.files.generation_config =
1639 Some(ModelAsset::from_bytes("generation_config.json", bytes));
1640 self
1641 }
1642
1643 pub fn with_preprocessor_config_file(mut self, path: impl Into<PathBuf>) -> Self {
1648 self.files.preprocessor_config = Some(ModelAsset::from_path(path.into()));
1649 self
1650 }
1651
1652 pub fn with_preprocessor_config_bytes(mut self, bytes: impl Into<Vec<u8>>) -> Self {
1654 self.files.preprocessor_config =
1655 Some(ModelAsset::from_bytes("preprocessor_config.json", bytes));
1656 self
1657 }
1658
1659 pub fn resolve_files(&self) -> Result<ModelFiles, TtsError> {
1671 let mut files = self.files.clone();
1672
1673 if !self.asset_bundle.is_empty() {
1674 files.fill_from_asset_bundle(&self.asset_bundle);
1675 }
1676
1677 if let Some(ref dir) = self.model_path {
1679 files.fill_from_directory(Path::new(dir));
1680 }
1681
1682 #[cfg(feature = "download")]
1684 {
1685 if !files.missing_files(self.model_type).is_empty() {
1686 let hf_id = self.effective_hf_model_id();
1687 info!("Downloading missing files from HuggingFace: {}", hf_id);
1688 files.fill_from_hf(hf_id, self.model_type, self.bearer_token.as_deref())?;
1689 }
1690 }
1691
1692 files.validate(self.model_type)?;
1694
1695 Ok(files)
1696 }
1697
1698 pub fn default_hf_model_id(&self) -> &str {
1700 match self.model_type {
1701 ModelType::Kokoro => "hexgrad/Kokoro-82M",
1702 ModelType::OmniVoice => "k2-fsa/OmniVoice",
1703 ModelType::Qwen3Tts => "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice",
1704 ModelType::VibeVoice => "microsoft/VibeVoice-1.5B",
1705 ModelType::VibeVoiceRealtime => "microsoft/VibeVoice-Realtime-0.5B",
1706 ModelType::Voxtral => "mistralai/Voxtral-4B-TTS-2603",
1707 }
1708 }
1709
1710 pub fn effective_hf_model_id(&self) -> &str {
1712 self.hf_model_id
1713 .as_deref()
1714 .unwrap_or_else(|| self.default_hf_model_id())
1715 }
1716
1717 pub fn effective_model_ref(&self) -> &str {
1719 self.model_path
1720 .as_deref()
1721 .unwrap_or_else(|| self.effective_hf_model_id())
1722 }
1723
1724 pub fn default_runtime_command(&self) -> Option<&str> {
1726 match self.model_type {
1727 ModelType::Voxtral => Some("python3"),
1728 ModelType::Kokoro
1729 | ModelType::OmniVoice
1730 | ModelType::Qwen3Tts
1731 | ModelType::VibeVoice
1732 | ModelType::VibeVoiceRealtime => None,
1733 }
1734 }
1735
1736 pub fn effective_runtime_command(&self) -> Option<&str> {
1738 self.runtime_command
1739 .as_deref()
1740 .or_else(|| self.default_runtime_command())
1741 }
1742
1743 pub fn default_runtime_endpoint(&self) -> Option<&str> {
1745 match self.model_type {
1746 ModelType::Kokoro
1747 | ModelType::OmniVoice
1748 | ModelType::Qwen3Tts
1749 | ModelType::VibeVoice
1750 | ModelType::VibeVoiceRealtime
1751 | ModelType::Voxtral => None,
1752 }
1753 }
1754
1755 pub fn effective_runtime_endpoint(&self) -> Option<&str> {
1757 self.runtime_endpoint
1758 .as_deref()
1759 .or_else(|| self.default_runtime_endpoint())
1760 }
1761
1762 pub fn effective_bearer_token(&self) -> &str {
1764 self.bearer_token.as_deref().unwrap_or("EMPTY")
1765 }
1766}
1767
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the label string of every dtype variant.
    #[test]
    fn test_dtype_labels_are_stable() {
        let expected = [
            (DType::F32, "f32"),
            (DType::F16, "f16"),
            (DType::BF16, "bf16"),
        ];
        for (dtype, label) in expected {
            assert_eq!(dtype.label(), label);
        }
    }

    /// Kokoro on Metal is expected to run in F32.
    #[test]
    fn test_kokoro_metal_prefers_f32() {
        let dtype = preferred_dtype_for(ModelType::Kokoro, DeviceSelection::Metal(0));
        assert_eq!(dtype, DType::F32);
    }

    /// Qwen3-TTS is the one model expected to run in BF16 on Metal.
    #[test]
    fn test_qwen3_metal_prefers_bf16() {
        let dtype = preferred_dtype_for(ModelType::Qwen3Tts, DeviceSelection::Metal(0));
        assert_eq!(dtype, DType::BF16);
    }

    /// OmniVoice on Metal is F32, checked through the runtime-choice label.
    #[test]
    fn test_omnivoice_metal_prefers_f32() {
        let dtype = preferred_dtype_for(ModelType::OmniVoice, DeviceSelection::Metal(0));
        let choice = RuntimeChoice {
            device: DeviceSelection::Metal(0),
            dtype,
        };
        assert_eq!(choice.label(), "metal:0 (f32)");
    }

    /// `with_preferred_runtime` must copy both halves of the preferred choice.
    #[test]
    fn test_with_preferred_runtime_applies_choice() {
        let RuntimeChoice { device, dtype } = preferred_runtime_choice(ModelType::VibeVoice);
        let config = TtsConfig::new(ModelType::VibeVoice).with_preferred_runtime();
        assert_eq!(config.device, device);
        assert_eq!(config.dtype, dtype);
    }

    /// An OmniVoice file set can be resolved from in-memory assets alone.
    #[test]
    fn test_resolve_files_from_in_memory_omnivoice_assets() {
        let assets = [
            ("config.json", 1u8),
            ("tokenizer.json", 2),
            ("model.safetensors", 3),
            ("audio_tokenizer/config.json", 4),
            ("audio_tokenizer/model.safetensors", 5),
        ];
        let mut bundle = ModelAssetBundle::new();
        for (path, byte) in assets {
            bundle = bundle.with_bytes(path, vec![byte]);
        }

        let config = TtsConfig::new(ModelType::OmniVoice).with_asset_bundle(bundle);
        let files = config.resolve_files().unwrap();

        assert!(matches!(files.config, Some(ModelAsset::Bytes { .. })));
        assert!(matches!(files.tokenizer, Some(ModelAsset::Bytes { .. })));
        assert_eq!(files.weights.len(), 1);
        assert_eq!(files.speech_tokenizer_weights.len(), 1);
    }

    /// A voice added by bytes with no voices directory set yields an
    /// in-memory directory holding `<name>.pt`.
    #[test]
    fn test_with_voice_bytes_creates_in_memory_voice_dir() {
        let config = TtsConfig::new(ModelType::Kokoro).with_voice_bytes("af_heart", vec![1, 2]);
        let voices_dir = config.files.voices_dir.as_ref().unwrap();
        let names = voices_dir.file_names().unwrap();
        assert_eq!(names, vec!["af_heart.pt"]);
    }

    /// The Voxtral asset manifest is non-empty and lists the voice
    /// embedding pattern.
    #[test]
    fn test_model_asset_manifest_is_available() {
        let requirements = ModelType::Voxtral.asset_requirements();
        assert!(!requirements.is_empty());

        let lists_voice_embeddings = requirements
            .iter()
            .any(|entry| entry.pattern == "voice_embedding/*.pt");
        assert!(lists_voice_embeddings);
    }
}