Skip to main content

any_tts/
config.rs

1//! Configuration types for TTS models.
2//!
3//! Provides a builder-pattern API for specifying model files individually
4//! (for custom download managers) or by directory, with HuggingFace Hub
5//! download as a fallback.
6
7use std::collections::{BTreeMap, HashMap};
8use std::path::{Path, PathBuf};
9use std::sync::Arc;
10
11use candle_nn::VarBuilder;
12use tracing::info;
13
14use crate::device::DeviceSelection;
15use crate::error::TtsError;
16use crate::models::ModelType;
17
/// Normalizes a logical asset path so equivalent spellings map to one key.
///
/// Backslashes are converted to `/`, then every leading `./` and `/`
/// segment is removed, in any interleaving (so `"/./a"` and `".//./a"` both
/// become `"a"`). Interior segments such as `a/./b` or `..` are left as-is.
fn normalize_asset_path(path: impl AsRef<str>) -> String {
    let unified = path.as_ref().replace('\\', "/");
    let mut rest = unified.as_str();
    // Strip repeatedly: a single pass of "./" followed by "/" would leave a
    // "./" behind for inputs like "/./a".
    loop {
        if let Some(stripped) = rest.strip_prefix("./") {
            rest = stripped;
        } else if let Some(stripped) = rest.strip_prefix('/') {
            rest = stripped;
        } else {
            break;
        }
    }
    rest.to_string()
}
25
26/// A single model asset that can come from disk or in-memory bytes.
27#[derive(Debug, Clone)]
28pub enum ModelAsset {
29    Path(PathBuf),
30    Bytes { name: String, data: Arc<[u8]> },
31}
32
33impl ModelAsset {
34    pub fn from_path(path: impl Into<PathBuf>) -> Self {
35        Self::Path(path.into())
36    }
37
38    pub fn from_bytes(name: impl Into<String>, bytes: impl Into<Vec<u8>>) -> Self {
39        Self::Bytes {
40            name: normalize_asset_path(name.into()),
41            data: Arc::from(bytes.into()),
42        }
43    }
44
45    pub fn as_path(&self) -> Option<&Path> {
46        match self {
47            Self::Path(path) => Some(path),
48            Self::Bytes { .. } => None,
49        }
50    }
51
52    pub fn file_name(&self) -> Option<&str> {
53        match self {
54            Self::Path(path) => path.file_name().and_then(|name| name.to_str()),
55            Self::Bytes { name, .. } => {
56                Path::new(name).file_name().and_then(|value| value.to_str())
57            }
58        }
59    }
60
61    pub fn extension(&self) -> Option<&str> {
62        match self {
63            Self::Path(path) => path.extension().and_then(|ext| ext.to_str()),
64            Self::Bytes { name, .. } => Path::new(name).extension().and_then(|ext| ext.to_str()),
65        }
66    }
67
68    pub fn display_name(&self) -> String {
69        match self {
70            Self::Path(path) => path.display().to_string(),
71            Self::Bytes { name, .. } => name.clone(),
72        }
73    }
74
75    pub fn read_bytes(&self) -> Result<Arc<[u8]>, TtsError> {
76        match self {
77            Self::Path(path) => std::fs::read(path).map(Arc::from).map_err(TtsError::from),
78            Self::Bytes { data, .. } => Ok(data.clone()),
79        }
80    }
81}
82
83/// A logical model asset directory, backed either by the filesystem or by in-memory bytes.
84#[derive(Debug, Clone)]
85pub enum ModelAssetDir {
86    Path(PathBuf),
87    Bytes(BTreeMap<String, Arc<[u8]>>),
88}
89
90impl ModelAssetDir {
91    pub fn from_path(path: impl Into<PathBuf>) -> Self {
92        Self::Path(path.into())
93    }
94
95    pub fn from_bytes(entries: BTreeMap<String, Arc<[u8]>>) -> Self {
96        Self::Bytes(entries)
97    }
98
99    pub fn load_file(&self, name: &str) -> Result<ModelAsset, TtsError> {
100        match self {
101            Self::Path(path) => {
102                let full_path = path.join(name);
103                if !full_path.exists() {
104                    return Err(TtsError::FileMissing(format!(
105                        "{} in {}",
106                        name,
107                        path.display()
108                    )));
109                }
110                Ok(ModelAsset::from_path(full_path))
111            }
112            Self::Bytes(entries) => entries
113                .get(name)
114                .cloned()
115                .map(|data| ModelAsset::Bytes {
116                    name: name.to_string(),
117                    data,
118                })
119                .ok_or_else(|| TtsError::FileMissing(name.to_string())),
120        }
121    }
122
123    pub fn file_names(&self) -> Result<Vec<String>, TtsError> {
124        match self {
125            Self::Path(path) => {
126                let mut names = Vec::new();
127                for entry in std::fs::read_dir(path)? {
128                    let entry = entry?;
129                    let Some(name) = entry.file_name().to_str().map(str::to_string) else {
130                        continue;
131                    };
132                    names.push(name);
133                }
134                names.sort();
135                Ok(names)
136            }
137            Self::Bytes(entries) => Ok(entries.keys().cloned().collect()),
138        }
139    }
140}
141
142/// A named collection of relative-path assets for byte-first loading.
143#[derive(Debug, Clone, Default)]
144pub struct ModelAssetBundle {
145    entries: BTreeMap<String, Arc<[u8]>>,
146}
147
148impl ModelAssetBundle {
149    pub fn new() -> Self {
150        Self::default()
151    }
152
153    pub fn insert_bytes(
154        &mut self,
155        relative_path: impl Into<String>,
156        bytes: impl Into<Vec<u8>>,
157    ) -> &mut Self {
158        let relative_path = normalize_asset_path(relative_path.into());
159        self.entries.insert(relative_path, Arc::from(bytes.into()));
160        self
161    }
162
163    pub fn with_bytes(
164        mut self,
165        relative_path: impl Into<String>,
166        bytes: impl Into<Vec<u8>>,
167    ) -> Self {
168        self.insert_bytes(relative_path, bytes);
169        self
170    }
171
172    pub fn is_empty(&self) -> bool {
173        self.entries.is_empty()
174    }
175
176    fn get(&self, relative_path: &str) -> Option<ModelAsset> {
177        let relative_path = normalize_asset_path(relative_path);
178        self.entries
179            .get(&relative_path)
180            .cloned()
181            .map(|data| ModelAsset::Bytes {
182                name: relative_path,
183                data,
184            })
185    }
186
187    fn collect_directory(&self, prefix: &str) -> Option<ModelAssetDir> {
188        let prefix = normalize_asset_path(prefix);
189        let prefix = if prefix.ends_with('/') {
190            prefix
191        } else {
192            format!("{prefix}/")
193        };
194
195        let mut entries = BTreeMap::new();
196        for (path, data) in &self.entries {
197            let Some(rest) = path.strip_prefix(&prefix) else {
198                continue;
199            };
200            if rest.is_empty() || rest.contains('/') {
201                continue;
202            }
203            entries.insert(rest.to_string(), data.clone());
204        }
205
206        if entries.is_empty() {
207            None
208        } else {
209            Some(ModelAssetDir::from_bytes(entries))
210        }
211    }
212
213    fn discover_sharded_weights(&self, prefix: &str) -> Vec<ModelAsset> {
214        let prefix = normalize_asset_path(prefix);
215        let prefix = if prefix.is_empty() {
216            String::new()
217        } else if prefix.ends_with('/') {
218            prefix
219        } else {
220            format!("{prefix}/")
221        };
222
223        let mut shards = self
224            .entries
225            .iter()
226            .filter_map(|(path, data)| {
227                let rest = if prefix.is_empty() {
228                    path.as_str()
229                } else {
230                    path.strip_prefix(&prefix)?
231                };
232                if rest.contains('/')
233                    || !rest.starts_with("model-")
234                    || !rest.ends_with(".safetensors")
235                {
236                    return None;
237                }
238                Some(ModelAsset::Bytes {
239                    name: path.clone(),
240                    data: data.clone(),
241                })
242            })
243            .collect::<Vec<_>>();
244        shards.sort_by_key(ModelAsset::display_name);
245        shards
246    }
247
248    fn discover_pth_weights(&self, prefix: &str) -> Vec<ModelAsset> {
249        let prefix = normalize_asset_path(prefix);
250        let prefix = if prefix.is_empty() {
251            String::new()
252        } else if prefix.ends_with('/') {
253            prefix
254        } else {
255            format!("{prefix}/")
256        };
257
258        let mut weights = self
259            .entries
260            .iter()
261            .filter_map(|(path, data)| {
262                let rest = if prefix.is_empty() {
263                    path.as_str()
264                } else {
265                    path.strip_prefix(&prefix)?
266                };
267                if rest.contains('/') || !rest.ends_with(".pth") {
268                    return None;
269                }
270                Some(ModelAsset::Bytes {
271                    name: path.clone(),
272                    data: data.clone(),
273                })
274            })
275            .collect::<Vec<_>>();
276        weights.sort_by_key(ModelAsset::display_name);
277        weights
278    }
279}
280
281// ---------------------------------------------------------------------------
282// ModelFiles — resolved model asset storage
283// ---------------------------------------------------------------------------
284
285/// Resolved model assets for loading.
286///
287/// Each model type requires a specific set of files. You can provide them
288/// individually using the builder methods on [`TtsConfig`], set
289/// [`TtsConfig::model_path`] to a directory that contains all of them, or
290/// rely on automatic HuggingFace Hub download (if the `download` feature
291/// is enabled).
292///
293/// ## File resolution order (per file)
294///
295/// 1. **Explicit path** — set via `with_*_file()` / `with_*_dir()` on
296///    [`TtsConfig`]. Use this when your project has its own download
297///    manager (e.g. flow-like hash-based local caching).
298/// 2. **Auto-discovery** — if `model_path` is set, the library looks for
299///    well-known filenames inside that directory.
300/// 3. **HuggingFace Hub download** — if the `download` feature is enabled
301///    and the file is still missing, it is fetched from the Hub. This is
302///    the convenient fallback for quick prototyping.
303#[derive(Debug, Clone, Default)]
304pub struct ModelFiles {
305    // ── Shared across all models ──────────────────────────────────────
306    /// Path to **`config.json`** — model architecture configuration.
307    ///
308    /// **Expected format:** JSON object describing the neural-network
309    /// hyperparameters (hidden size, number of layers, vocab size, …).
310    /// This is the standard HuggingFace `config.json` format.
311    /// Each backend stores its architecture metadata here, such as
312    /// transformer dimensions, tokenizer sizes, sample rates, or
313    /// auxiliary decoder configuration.
314    pub config: Option<ModelAsset>,
315
316    /// Path to **`tokenizer.json`** — BPE text tokenizer definition.
317    ///
318    /// **Expected format:** [HuggingFace Tokenizers](https://huggingface.co/docs/tokenizers)
319    /// self-contained JSON file. Contains the full vocabulary, merge rules,
320    /// special tokens, and pre/post-processing steps. No separate
321    /// `vocab.json` or `merges.txt` required when this file is present.
322    ///
323    /// Used by both models to convert input text into token IDs before
324    /// feeding them to the transformer backbone.
325    pub tokenizer: Option<ModelAsset>,
326
327    /// Paths to **model weight files** (`.safetensors`).
328    ///
329    /// **Expected format:** One or more [SafeTensors](https://huggingface.co/docs/safetensors)
330    /// files containing the neural-network parameters.
331    ///
332    /// * **Single file** — `model.safetensors` (for models < ~5 GB).
333    /// * **Sharded** — `model-00001-of-00004.safetensors`, … When
334    ///   sharded, the library also expects `model.safetensors.index.json`
335    ///   in the same directory (auto-discovered or downloaded).
336    /// * **Other formats** — some backends use `consolidated.safetensors`
337    ///   or `.pth` files instead of the standard filename.
338    pub weights: Vec<ModelAsset>,
339
340    // ── Voice asset directories ───────────────────────────────────────
341    /// Path to a **voice asset directory** for backends that ship preset voices.
342    ///
343    /// Supported layouts include:
344    ///
345    /// ```text
346    /// voices/                ← Kokoro preset voices (`*.pt`)
347    /// voice_embedding/       ← Voxtral preset voices (`*.pt`)
348    /// ```
349    ///
350    /// The exact file format depends on the backend.
351    pub voices_dir: Option<ModelAssetDir>,
352
353    // ── Qwen3-TTS / OmniVoice-specific ────────────────────────────────
354    /// Paths to the **speech/audio tokenizer decoder** weight files.
355    ///
356    /// **Expected format:** SafeTensors files for the auxiliary decoder used
357    /// by models that emit discrete audio codec tokens.
358    ///
359    /// Contains:
360    /// * Residual VQ codebooks (16 groups × 2048 codes × dim)
361    /// * Pre-conv + pre-transformer layers
362    /// * Upsampling layers (transposed convolutions + SnakeBeta)
363    /// * Final decoder convolution
364    ///
365    /// * **Qwen3-TTS** uses the separate
366    ///   `Qwen/Qwen3-TTS-Tokenizer-12Hz` repository.
367    /// * **OmniVoice** uses the `audio_tokenizer/` subdirectory inside the
368    ///   main model snapshot.
369    pub speech_tokenizer_weights: Vec<ModelAsset>,
370
371    /// Path to **`config.json`** of the speech/audio tokenizer.
372    ///
373    /// **Expected format:** JSON config for the speech tokenizer decoder
374    /// model, including codebook dimensions, upsampling ratios, and
375    /// activation parameters.
376    ///
377    /// If not provided, will be auto-discovered from a nested
378    /// `audio_tokenizer/` directory or downloaded from HuggingFace.
379    pub speech_tokenizer_config: Option<ModelAsset>,
380
381    /// Path to **`generation_config.json`** (optional).
382    ///
383    /// **Expected format:** Standard HuggingFace generation configuration
384    /// with fields like `max_new_tokens`, `top_p`, `temperature`,
385    /// `do_sample`, `repetition_penalty`, etc.
386    ///
387    /// If not provided, sensible per-model defaults are used.
388    pub generation_config: Option<ModelAsset>,
389
390    /// Path to **`preprocessor_config.json`** (optional).
391    ///
392    /// Used by backends such as VibeVoice that publish prompt-building and
393    /// audio-normalization defaults separately from `config.json`.
394    pub preprocessor_config: Option<ModelAsset>,
395}
396
397impl ModelFiles {
398    /// Scan a directory for well-known model files and fill any that are
399    /// still `None` / empty.
400    pub fn fill_from_directory(&mut self, dir: &Path) {
401        // config.json
402        if self.config.is_none() {
403            let p = dir.join("config.json");
404            if p.exists() {
405                info!("Auto-discovered config: {}", p.display());
406                self.config = Some(ModelAsset::from_path(p));
407            } else {
408                let p = dir.join("params.json");
409                if p.exists() {
410                    info!("Auto-discovered config: {}", p.display());
411                    self.config = Some(ModelAsset::from_path(p));
412                }
413            }
414        }
415
416        // tokenizer.json
417        if self.tokenizer.is_none() {
418            let p = dir.join("tokenizer.json");
419            if p.exists() {
420                info!("Auto-discovered tokenizer: {}", p.display());
421                self.tokenizer = Some(ModelAsset::from_path(p));
422            } else {
423                let p = dir.join("tekken.json");
424                if p.exists() {
425                    info!("Auto-discovered tokenizer: {}", p.display());
426                    self.tokenizer = Some(ModelAsset::from_path(p));
427                }
428            }
429        }
430
431        // Model weights
432        if self.weights.is_empty() {
433            let single = dir.join("model.safetensors");
434            if single.exists() {
435                info!("Auto-discovered single weight file");
436                self.weights.push(ModelAsset::from_path(single));
437            } else {
438                let single = dir.join("consolidated.safetensors");
439                if single.exists() {
440                    info!("Auto-discovered single weight file");
441                    self.weights.push(ModelAsset::from_path(single));
442                } else {
443                    self.discover_sharded_weights(dir);
444                }
445            }
446            // Fall back to .pth files (Kokoro uses PyTorch .pth format)
447            if self.weights.is_empty() {
448                self.discover_pth_weights(dir);
449            }
450        }
451
452        // Voice asset directory (Kokoro / Voxtral)
453        if self.voices_dir.is_none() {
454            let p = dir.join("voices");
455            if p.is_dir() {
456                info!("Auto-discovered voices dir: {}", p.display());
457                self.voices_dir = Some(ModelAssetDir::from_path(p));
458            } else {
459                let p = dir.join("voice_embedding");
460                if p.is_dir() {
461                    info!("Auto-discovered voices dir: {}", p.display());
462                    self.voices_dir = Some(ModelAssetDir::from_path(p));
463                }
464            }
465        }
466
467        // generation_config.json
468        if self.generation_config.is_none() {
469            let p = dir.join("generation_config.json");
470            if p.exists() {
471                info!("Auto-discovered generation config: {}", p.display());
472                self.generation_config = Some(ModelAsset::from_path(p));
473            }
474        }
475
476        if self.preprocessor_config.is_none() {
477            let p = dir.join("preprocessor_config.json");
478            if p.exists() {
479                info!("Auto-discovered preprocessor config: {}", p.display());
480                self.preprocessor_config = Some(ModelAsset::from_path(p));
481            }
482        }
483
484        for nested_dir_name in ["audio_tokenizer", "speech_tokenizer"] {
485            let nested_dir = dir.join(nested_dir_name);
486            if !nested_dir.is_dir() {
487                continue;
488            }
489
490            if self.speech_tokenizer_config.is_none() {
491                let p = nested_dir.join("config.json");
492                if p.exists() {
493                    info!(
494                        "Auto-discovered {} config: {}",
495                        nested_dir_name,
496                        p.display()
497                    );
498                    self.speech_tokenizer_config = Some(ModelAsset::from_path(p));
499                }
500            }
501
502            if self.speech_tokenizer_weights.is_empty() {
503                let single = nested_dir.join("model.safetensors");
504                if single.exists() {
505                    info!("Auto-discovered {} weight file", nested_dir_name);
506                    self.speech_tokenizer_weights
507                        .push(ModelAsset::from_path(single));
508                } else {
509                    let mut shards = Self::discover_sharded_weights_in_dir(&nested_dir);
510                    if !shards.is_empty() {
511                        info!(
512                            "Auto-discovered {} {} weight shards",
513                            shards.len(),
514                            nested_dir_name
515                        );
516                        self.speech_tokenizer_weights.append(&mut shards);
517                    }
518                }
519            }
520        }
521    }
522
523    /// Scan an in-memory asset bundle for well-known model files.
524    pub fn fill_from_asset_bundle(&mut self, bundle: &ModelAssetBundle) {
525        if self.config.is_none() {
526            self.config = bundle
527                .get("config.json")
528                .or_else(|| bundle.get("params.json"));
529        }
530
531        if self.tokenizer.is_none() {
532            self.tokenizer = bundle
533                .get("tokenizer.json")
534                .or_else(|| bundle.get("tekken.json"));
535        }
536
537        if self.weights.is_empty() {
538            if let Some(asset) = bundle.get("model.safetensors") {
539                self.weights.push(asset);
540            } else if let Some(asset) = bundle.get("consolidated.safetensors") {
541                self.weights.push(asset);
542            } else {
543                self.weights = bundle.discover_sharded_weights("");
544            }
545            if self.weights.is_empty() {
546                self.weights = bundle.discover_pth_weights("");
547            }
548        }
549
550        if self.voices_dir.is_none() {
551            self.voices_dir = bundle
552                .collect_directory("voices")
553                .or_else(|| bundle.collect_directory("voice_embedding"));
554        }
555
556        if self.generation_config.is_none() {
557            self.generation_config = bundle.get("generation_config.json");
558        }
559
560        if self.preprocessor_config.is_none() {
561            self.preprocessor_config = bundle.get("preprocessor_config.json");
562        }
563
564        for nested_dir_name in ["audio_tokenizer", "speech_tokenizer"] {
565            if self.speech_tokenizer_config.is_none() {
566                self.speech_tokenizer_config =
567                    bundle.get(format!("{nested_dir_name}/config.json").as_str());
568            }
569
570            if self.speech_tokenizer_weights.is_empty() {
571                if let Some(asset) =
572                    bundle.get(format!("{nested_dir_name}/model.safetensors").as_str())
573                {
574                    self.speech_tokenizer_weights.push(asset);
575                } else {
576                    self.speech_tokenizer_weights =
577                        bundle.discover_sharded_weights(nested_dir_name);
578                }
579            }
580        }
581    }
582
583    /// Look for `.pth` PyTorch weight files (e.g. Kokoro's kokoro-v1_0.pth).
584    fn discover_pth_weights(&mut self, dir: &Path) {
585        let Ok(entries) = std::fs::read_dir(dir) else {
586            return;
587        };
588
589        let mut pth_files: Vec<ModelAsset> = entries
590            .filter_map(|e| e.ok())
591            .map(|e| e.path())
592            .filter(|p| {
593                p.extension()
594                    .and_then(|ext| ext.to_str())
595                    .is_some_and(|ext| ext == "pth")
596            })
597            .map(ModelAsset::from_path)
598            .collect();
599
600        if !pth_files.is_empty() {
601            pth_files.sort_by_key(ModelAsset::display_name);
602            info!("Auto-discovered {} .pth weight file(s)", pth_files.len());
603            self.weights = pth_files;
604        }
605    }
606
607    /// Look for `model-NNNNN-of-NNNNN.safetensors` shard files.
608    fn discover_sharded_weights(&mut self, dir: &Path) {
609        let shards = Self::discover_sharded_weights_in_dir(dir);
610
611        if !shards.is_empty() {
612            info!("Auto-discovered {} weight shards", shards.len());
613            self.weights = shards;
614        }
615    }
616
617    fn discover_sharded_weights_in_dir(dir: &Path) -> Vec<ModelAsset> {
618        let Ok(entries) = std::fs::read_dir(dir) else {
619            return Vec::new();
620        };
621
622        let mut shards: Vec<ModelAsset> = entries
623            .filter_map(|e| e.ok())
624            .map(|e| e.path())
625            .filter(|p| {
626                p.file_name()
627                    .and_then(|n| n.to_str())
628                    .is_some_and(|n| n.starts_with("model-") && n.ends_with(".safetensors"))
629            })
630            .map(ModelAsset::from_path)
631            .collect();
632        shards.sort_by_key(ModelAsset::display_name);
633        shards
634    }
635
636    /// Build a [`VarBuilder`] by reading safetensors files fully into memory.
637    ///
638    /// This is the **safe** alternative to `VarBuilder::from_mmaped_safetensors`
639    /// which requires `unsafe` due to memory-mapping. The trade-off is a brief
640    /// peak in memory while the raw bytes and parsed tensors coexist, but for
641    /// model loading this is negligible compared to the final tensor footprint.
642    pub fn load_safetensors_vb(
643        assets: &[ModelAsset],
644        dtype: candle_core::DType,
645        device: &candle_core::Device,
646    ) -> Result<VarBuilder<'static>, TtsError> {
647        if assets.is_empty() {
648            return Err(TtsError::FileMissing("safetensors weight files".into()));
649        }
650
651        if assets.len() == 1 {
652            if let Some(path) = assets[0].as_path() {
653                let data = std::fs::read(path).map_err(|e| {
654                    TtsError::WeightLoadError(format!("Failed to read {}: {}", path.display(), e))
655                })?;
656                return VarBuilder::from_buffered_safetensors(data, dtype, device)
657                    .map_err(|e| TtsError::WeightLoadError(e.to_string()));
658            }
659        }
660
661        // Multi-file: read each shard, collect all tensors into one HashMap
662        let mut all_tensors: HashMap<String, candle_core::Tensor> = HashMap::new();
663        for asset in assets {
664            let data = asset.read_bytes().map_err(|e| {
665                TtsError::WeightLoadError(format!("Failed to read {}: {}", asset.display_name(), e))
666            })?;
667            let tensors = safetensors::SafeTensors::deserialize(&data).map_err(|e| {
668                TtsError::WeightLoadError(format!(
669                    "Failed to parse {}: {}",
670                    asset.display_name(),
671                    e
672                ))
673            })?;
674            for (name, view) in tensors.tensors() {
675                // Get the native dtype from the safetensors file
676                let native_dtype = match view.dtype() {
677                    safetensors::Dtype::F16 => candle_core::DType::F16,
678                    safetensors::Dtype::BF16 => candle_core::DType::BF16,
679                    safetensors::Dtype::F32 => candle_core::DType::F32,
680                    safetensors::Dtype::F64 => candle_core::DType::F64,
681                    safetensors::Dtype::I64 => candle_core::DType::I64,
682                    safetensors::Dtype::I32 => candle_core::DType::I64, // candle has no I32
683                    safetensors::Dtype::U32 => candle_core::DType::U32,
684                    safetensors::Dtype::U8 => candle_core::DType::U8,
685                    _ => candle_core::DType::F32, // Fallback
686                };
687
688                // Load in native dtype first
689                let tensor = candle_core::Tensor::from_raw_buffer(
690                    view.data(),
691                    native_dtype,
692                    view.shape(),
693                    device,
694                )
695                .map_err(|e| {
696                    TtsError::WeightLoadError(format!("Failed to load tensor '{}': {}", name, e))
697                })?;
698
699                // Convert to target dtype if different
700                let tensor = if native_dtype != dtype {
701                    tensor.to_dtype(dtype).map_err(|e| {
702                        TtsError::WeightLoadError(format!(
703                            "Failed to convert tensor '{}' to {:?}: {}",
704                            name, dtype, e
705                        ))
706                    })?
707                } else {
708                    tensor
709                };
710
711                all_tensors.insert(name, tensor);
712            }
713        }
714
715        Ok(VarBuilder::from_tensors(all_tensors, dtype, device))
716    }
717
718    /// Download missing files from HuggingFace Hub.
719    ///
720    /// `model_type` determines which files are required.
721    #[cfg(feature = "download")]
722    pub fn fill_from_hf(
723        &mut self,
724        model_id: &str,
725        model_type: ModelType,
726        bearer_token: Option<&str>,
727    ) -> Result<(), TtsError> {
728        use crate::download::download_file_with_token;
729
730        let download = |repo: &str, file: &str| download_file_with_token(repo, file, bearer_token);
731
732        // config
733        if self.config.is_none() {
734            let config_name = if model_type == ModelType::Voxtral {
735                "params.json"
736            } else {
737                "config.json"
738            };
739            info!("Downloading {} from {}", config_name, model_id);
740            self.config = Some(ModelAsset::from_path(download(model_id, config_name)?));
741        }
742
743        // tokenizer.json (Kokoro uses phoneme vocab from config.json instead)
744        if model_type != ModelType::Kokoro && self.tokenizer.is_none() {
745            let tokenizer_name = if model_type == ModelType::Voxtral {
746                "tekken.json"
747            } else {
748                "tokenizer.json"
749            };
750            info!("Downloading {} from {}", tokenizer_name, model_id);
751            match download(model_id, tokenizer_name) {
752                Ok(p) => self.tokenizer = Some(ModelAsset::from_path(p)),
753                Err(_) => {
754                    if model_type == ModelType::Voxtral {
755                        return Err(TtsError::FileMissing(
756                            "tekken.json — Voxtral Tekken tokenizer".to_string(),
757                        ));
758                    }
759                    let fallback_repo = match model_type {
760                        ModelType::Qwen3Tts => "Qwen/Qwen2.5-0.5B",
761                        ModelType::VibeVoice => "Qwen/Qwen2.5-1.5B",
762                        ModelType::VibeVoiceRealtime => "Qwen/Qwen2.5-0.5B",
763                        _ => "Qwen/Qwen2.5-0.5B",
764                    };
765                    info!(
766                        "tokenizer.json not in {}; falling back to {}",
767                        model_id, fallback_repo
768                    );
769                    self.tokenizer = Some(ModelAsset::from_path(download(
770                        fallback_repo,
771                        "tokenizer.json",
772                    )?));
773                }
774            }
775        }
776
777        // generation_config.json (optional — ignore download errors)
778        if self.generation_config.is_none() {
779            if let Ok(p) = download(model_id, "generation_config.json") {
780                self.generation_config = Some(ModelAsset::from_path(p));
781            }
782        }
783
784        if self.preprocessor_config.is_none() {
785            if let Ok(p) = download(model_id, "preprocessor_config.json") {
786                self.preprocessor_config = Some(ModelAsset::from_path(p));
787            }
788        }
789
790        // Model weights
791        if self.weights.is_empty() {
792            self.download_weights_from_hf(model_id, bearer_token)?;
793        }
794
795        // Model-specific extras
796        match model_type {
797            ModelType::Kokoro => {
798                self.download_kokoro_extras(model_id, bearer_token)?;
799            }
800            ModelType::OmniVoice => {
801                self.download_omnivoice_extras(model_id, bearer_token)?;
802            }
803            ModelType::Voxtral => {
804                self.download_voxtral_extras(model_id, bearer_token)?;
805            }
806            ModelType::Qwen3Tts => {
807                self.download_qwen3tts_extras(bearer_token)?;
808            }
809            ModelType::VibeVoice | ModelType::VibeVoiceRealtime => {
810                self.download_vibevoice_extras(model_id, bearer_token)?;
811            }
812        }
813
814        Ok(())
815    }
816
817    /// Download weight files (single or sharded) from HuggingFace.
818    #[cfg(feature = "download")]
819    fn download_weights_from_hf(
820        &mut self,
821        model_id: &str,
822        bearer_token: Option<&str>,
823    ) -> Result<(), TtsError> {
824        use crate::download::download_file_with_token;
825
826        let download = |repo: &str, file: &str| download_file_with_token(repo, file, bearer_token);
827
828        // Try single safetensors file first
829        if let Ok(p) = download(model_id, "model.safetensors") {
830            self.weights.push(ModelAsset::from_path(p));
831            return Ok(());
832        }
833
834        // Voxtral ships a single consolidated.safetensors file.
835        if let Ok(p) = download(model_id, "consolidated.safetensors") {
836            self.weights.push(ModelAsset::from_path(p));
837            return Ok(());
838        }
839
840        // Try .pth file (Kokoro uses PyTorch format)
841        for pth_name in &["kokoro-v1_0.pth", "kokoro-v1_1-zh.pth", "model.pth"] {
842            if let Ok(p) = download(model_id, pth_name) {
843                self.weights.push(ModelAsset::from_path(p));
844                return Ok(());
845            }
846        }
847
848        // Fall back to sharded — download the index
849        let index_path = download(model_id, "model.safetensors.index.json")?;
850        let index_content = std::fs::read_to_string(&index_path)?;
851        let index: serde_json::Value = serde_json::from_str(&index_content)?;
852
853        if let Some(weight_map) = index.get("weight_map").and_then(|v| v.as_object()) {
854            let mut shard_names: Vec<String> = weight_map
855                .values()
856                .filter_map(|v| v.as_str().map(String::from))
857                .collect();
858            shard_names.sort();
859            shard_names.dedup();
860
861            for shard_name in &shard_names {
862                info!("Downloading shard: {}", shard_name);
863                let p = download(model_id, shard_name)?;
864                self.weights.push(ModelAsset::from_path(p));
865            }
866        }
867
868        Ok(())
869    }
870
871    /// Download Kokoro-specific files (voices directory).
872    #[cfg(feature = "download")]
873    fn download_kokoro_extras(
874        &mut self,
875        model_id: &str,
876        bearer_token: Option<&str>,
877    ) -> Result<(), TtsError> {
878        use crate::download::download_file_with_token;
879
880        let download = |repo: &str, file: &str| download_file_with_token(repo, file, bearer_token);
881
882        if self.voices_dir.is_none() {
883            // Download a well-known voice to discover the voices directory
884            if let Ok(voice_path) = download(model_id, "voices/af_heart.pt") {
885                if let Some(parent) = voice_path.parent() {
886                    self.voices_dir = Some(ModelAssetDir::from_path(parent.to_path_buf()));
887                }
888            }
889        }
890
891        Ok(())
892    }
893
894    /// Download Qwen3-TTS-specific files (speech tokenizer from separate repo).
895    #[cfg(feature = "download")]
896    fn download_qwen3tts_extras(&mut self, bearer_token: Option<&str>) -> Result<(), TtsError> {
897        use crate::download::download_file_with_token;
898
899        let tokenizer_repo = "Qwen/Qwen3-TTS-Tokenizer-12Hz";
900        let download = |repo: &str, file: &str| download_file_with_token(repo, file, bearer_token);
901
902        if self.speech_tokenizer_config.is_none() {
903            info!(
904                "Downloading speech tokenizer config from {}",
905                tokenizer_repo
906            );
907            if let Ok(p) = download(tokenizer_repo, "config.json") {
908                self.speech_tokenizer_config = Some(ModelAsset::from_path(p));
909            }
910        }
911
912        if self.speech_tokenizer_weights.is_empty() {
913            info!(
914                "Downloading speech tokenizer weights from {}",
915                tokenizer_repo
916            );
917            if let Ok(p) = download(tokenizer_repo, "model.safetensors") {
918                self.speech_tokenizer_weights.push(ModelAsset::from_path(p));
919            } else if let Ok(index_path) = download(tokenizer_repo, "model.safetensors.index.json")
920            {
921                if let Ok(content) = std::fs::read_to_string(&index_path) {
922                    if let Ok(index) = serde_json::from_str::<serde_json::Value>(&content) {
923                        if let Some(weight_map) =
924                            index.get("weight_map").and_then(|v| v.as_object())
925                        {
926                            let mut shard_names: Vec<String> = weight_map
927                                .values()
928                                .filter_map(|v| v.as_str().map(String::from))
929                                .collect();
930                            shard_names.sort();
931                            shard_names.dedup();
932
933                            for shard_name in &shard_names {
934                                if let Ok(p) = download(tokenizer_repo, shard_name) {
935                                    self.speech_tokenizer_weights.push(ModelAsset::from_path(p));
936                                }
937                            }
938                        }
939                    }
940                }
941            }
942        }
943
944        Ok(())
945    }
946
947    #[cfg(feature = "download")]
948    fn download_vibevoice_extras(
949        &mut self,
950        model_id: &str,
951        bearer_token: Option<&str>,
952    ) -> Result<(), TtsError> {
953        use crate::download::download_file_with_token;
954
955        let download = |repo: &str, file: &str| download_file_with_token(repo, file, bearer_token);
956
957        if self.preprocessor_config.is_none() {
958            if let Ok(p) = download(model_id, "preprocessor_config.json") {
959                self.preprocessor_config = Some(ModelAsset::from_path(p));
960            }
961        }
962
963        Ok(())
964    }
965
966    /// Download OmniVoice-specific files (audio tokenizer subdirectory).
967    #[cfg(feature = "download")]
968    fn download_omnivoice_extras(
969        &mut self,
970        model_id: &str,
971        bearer_token: Option<&str>,
972    ) -> Result<(), TtsError> {
973        use crate::download::download_file_with_token;
974
975        let download = |repo: &str, file: &str| download_file_with_token(repo, file, bearer_token);
976
977        if self.speech_tokenizer_config.is_none() {
978            if let Ok(p) = download(model_id, "audio_tokenizer/config.json") {
979                self.speech_tokenizer_config = Some(ModelAsset::from_path(p));
980            }
981        }
982
983        if self.speech_tokenizer_weights.is_empty() {
984            if let Ok(p) = download(model_id, "audio_tokenizer/model.safetensors") {
985                self.speech_tokenizer_weights.push(ModelAsset::from_path(p));
986            } else if let Ok(index_path) =
987                download(model_id, "audio_tokenizer/model.safetensors.index.json")
988            {
989                if let Ok(content) = std::fs::read_to_string(&index_path) {
990                    if let Ok(index) = serde_json::from_str::<serde_json::Value>(&content) {
991                        if let Some(weight_map) =
992                            index.get("weight_map").and_then(|v| v.as_object())
993                        {
994                            let mut shard_names: Vec<String> = weight_map
995                                .values()
996                                .filter_map(|v| v.as_str().map(String::from))
997                                .collect();
998                            shard_names.sort();
999                            shard_names.dedup();
1000
1001                            for shard_name in &shard_names {
1002                                let shard_path = format!("audio_tokenizer/{}", shard_name);
1003                                if let Ok(p) = download(model_id, &shard_path) {
1004                                    self.speech_tokenizer_weights.push(ModelAsset::from_path(p));
1005                                }
1006                            }
1007                        }
1008                    }
1009                }
1010            }
1011        }
1012
1013        Ok(())
1014    }
1015
1016    /// Download Voxtral-specific files (preset voice embeddings).
1017    #[cfg(feature = "download")]
1018    fn download_voxtral_extras(
1019        &mut self,
1020        model_id: &str,
1021        bearer_token: Option<&str>,
1022    ) -> Result<(), TtsError> {
1023        use crate::download::download_file_with_token;
1024
1025        let download = |repo: &str, file: &str| download_file_with_token(repo, file, bearer_token);
1026
1027        if self.voices_dir.is_some() {
1028            return Ok(());
1029        }
1030
1031        let config_path = self.config.as_ref().ok_or_else(|| {
1032            TtsError::FileMissing("params.json — Voxtral model configuration".to_string())
1033        })?;
1034        let content = config_path.read_bytes()?;
1035        let config: serde_json::Value = serde_json::from_slice(&content)?;
1036        let voices = config
1037            .get("multimodal")
1038            .and_then(|v| v.get("audio_tokenizer_args"))
1039            .and_then(|v| v.get("voice"))
1040            .and_then(|v| v.as_object())
1041            .ok_or_else(|| {
1042                TtsError::ConfigError(
1043                    "params.json is missing multimodal.audio_tokenizer_args.voice".to_string(),
1044                )
1045            })?;
1046
1047        let mut discovered_dir: Option<ModelAssetDir> = None;
1048        for voice_name in voices.keys() {
1049            let path = download(model_id, &format!("voice_embedding/{voice_name}.pt"))?;
1050            if discovered_dir.is_none() {
1051                discovered_dir = path
1052                    .parent()
1053                    .map(|parent| ModelAssetDir::from_path(parent.to_path_buf()));
1054            }
1055        }
1056
1057        self.voices_dir = discovered_dir;
1058        Ok(())
1059    }
1060
1061    /// Check whether all required files for the given model type are present.
1062    pub fn validate(&self, model_type: ModelType) -> Result<(), TtsError> {
1063        if model_type == ModelType::Voxtral {
1064            if self.config.is_none() {
1065                return Err(TtsError::FileMissing(
1066                    "params.json — Voxtral model configuration".to_string(),
1067                ));
1068            }
1069            if self.tokenizer.is_none() {
1070                return Err(TtsError::FileMissing(
1071                    "tekken.json — Voxtral Tekken tokenizer".to_string(),
1072                ));
1073            }
1074            if self.weights.is_empty() {
1075                return Err(TtsError::FileMissing(
1076                    "consolidated.safetensors — Voxtral model weights".to_string(),
1077                ));
1078            }
1079            if self.voices_dir.is_none() {
1080                return Err(TtsError::FileMissing(
1081                    "voice_embedding/ — Voxtral preset voice embeddings".to_string(),
1082                ));
1083            }
1084            return Ok(());
1085        }
1086
1087        if self.config.is_none() {
1088            return Err(TtsError::FileMissing(
1089                "config.json — model architecture configuration".to_string(),
1090            ));
1091        }
1092        // Kokoro uses phoneme vocab from config.json, no tokenizer.json needed
1093        if model_type != ModelType::Kokoro && self.tokenizer.is_none() {
1094            return Err(TtsError::FileMissing(
1095                "tokenizer.json — BPE text tokenizer".to_string(),
1096            ));
1097        }
1098        if self.weights.is_empty() {
1099            return Err(TtsError::FileMissing(
1100                "model weight files (.safetensors or .pth)".to_string(),
1101            ));
1102        }
1103
1104        match model_type {
1105            ModelType::OmniVoice => {
1106                if self.speech_tokenizer_config.is_none() {
1107                    return Err(TtsError::FileMissing(
1108                        "audio tokenizer config (audio_tokenizer/config.json) \
1109                         — configures OmniVoice's codec decoder"
1110                            .to_string(),
1111                    ));
1112                }
1113                if self.speech_tokenizer_weights.is_empty() {
1114                    return Err(TtsError::FileMissing(
1115                        "audio tokenizer weights (audio_tokenizer/model.safetensors) \
1116                         — converts OmniVoice codec tokens to audio waveform"
1117                            .to_string(),
1118                    ));
1119                }
1120            }
1121            ModelType::Qwen3Tts => {
1122                if self.speech_tokenizer_weights.is_empty() {
1123                    return Err(TtsError::FileMissing(
1124                        "speech tokenizer weights (Qwen3-TTS-Tokenizer-12Hz model.safetensors) \
1125                         — converts codec tokens to audio waveform"
1126                            .to_string(),
1127                    ));
1128                }
1129            }
1130            ModelType::Kokoro => {
1131                // voices_dir is optional
1132            }
1133            ModelType::VibeVoice | ModelType::VibeVoiceRealtime => {}
1134            ModelType::Voxtral => unreachable!(),
1135        }
1136
1137        Ok(())
1138    }
1139
1140    /// Return the list of files that are required but not yet set.
1141    pub fn missing_files(&self, model_type: ModelType) -> Vec<&'static str> {
1142        if model_type == ModelType::Voxtral {
1143            let mut missing = Vec::new();
1144            if self.config.is_none() {
1145                missing.push("params.json");
1146            }
1147            if self.tokenizer.is_none() {
1148                missing.push("tekken.json");
1149            }
1150            if self.weights.is_empty() {
1151                missing.push("consolidated.safetensors");
1152            }
1153            if self.voices_dir.is_none() {
1154                missing.push("voice_embedding");
1155            }
1156            return missing;
1157        }
1158
1159        let mut missing = Vec::new();
1160
1161        if self.config.is_none() {
1162            missing.push("config.json");
1163        }
1164        if model_type != ModelType::Kokoro && self.tokenizer.is_none() {
1165            missing.push("tokenizer.json");
1166        }
1167        if self.weights.is_empty() {
1168            missing.push("model weight files");
1169        }
1170        if model_type == ModelType::OmniVoice && self.speech_tokenizer_config.is_none() {
1171            missing.push("audio tokenizer config");
1172        }
1173        if model_type == ModelType::OmniVoice && self.speech_tokenizer_weights.is_empty() {
1174            missing.push("audio tokenizer weights");
1175        }
1176        if model_type == ModelType::Qwen3Tts && self.speech_tokenizer_weights.is_empty() {
1177            missing.push("speech tokenizer weights");
1178        }
1179
1180        missing
1181    }
1182}
1183
1184// ---------------------------------------------------------------------------
1185// DType
1186// ---------------------------------------------------------------------------
1187
/// Floating-point precision used to load and run model weights.
///
/// The default is [`DType::BF16`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum DType {
    /// 32-bit float: the most compatible choice, with the largest memory footprint.
    F32,
    /// 16-bit float: half the memory of `F32`.
    F16,
    /// Brain float 16: the default, generally preferred for transformer models.
    #[default]
    BF16,
}
1199
1200impl DType {
1201    /// Convert to candle's DType.
1202    pub fn to_candle(self) -> candle_core::DType {
1203        match self {
1204            Self::F32 => candle_core::DType::F32,
1205            Self::F16 => candle_core::DType::F16,
1206            Self::BF16 => candle_core::DType::BF16,
1207        }
1208    }
1209
1210    /// Human-readable dtype label.
1211    pub fn label(self) -> &'static str {
1212        match self {
1213            Self::F32 => "f32",
1214            Self::F16 => "f16",
1215            Self::BF16 => "bf16",
1216        }
1217    }
1218}
1219
1220/// Preferred runtime choice for a model on the current machine.
1221///
1222/// This reflects the compiled backend features, runtime hardware
1223/// availability, and the crate's model-specific dtype safety rules.
1224#[derive(Debug, Clone, Copy, PartialEq, Eq)]
1225pub struct RuntimeChoice {
1226    /// Concrete device selection for this runtime.
1227    pub device: DeviceSelection,
1228    /// Recommended dtype for the chosen device and model.
1229    pub dtype: DType,
1230}
1231
1232impl RuntimeChoice {
1233    /// Combined device and dtype label.
1234    pub fn label(&self) -> String {
1235        format!("{} ({})", self.device.label(), self.dtype.label())
1236    }
1237}
1238
1239/// Preferred runtime choices for a model, ordered fastest-first.
1240pub fn preferred_runtime_choices(model_type: ModelType) -> Vec<RuntimeChoice> {
1241    DeviceSelection::available_runtime_candidates()
1242        .into_iter()
1243        .map(|device| RuntimeChoice {
1244            device,
1245            dtype: preferred_dtype_for(model_type, device),
1246        })
1247        .collect()
1248}
1249
1250/// Best runtime choice for a model on the current machine.
1251pub fn preferred_runtime_choice(model_type: ModelType) -> RuntimeChoice {
1252    preferred_runtime_choices(model_type)
1253        .into_iter()
1254        .next()
1255        .unwrap_or(RuntimeChoice {
1256            device: DeviceSelection::Cpu,
1257            dtype: DType::F32,
1258        })
1259}
1260
1261fn preferred_dtype_for(model_type: ModelType, device: DeviceSelection) -> DType {
1262    match model_type {
1263        ModelType::OmniVoice => match device {
1264            DeviceSelection::Cpu => DType::F32,
1265            DeviceSelection::Cuda(_) => DType::BF16,
1266            DeviceSelection::Metal(_) => DType::F32,
1267            DeviceSelection::Auto => DType::BF16,
1268        },
1269        ModelType::Kokoro => match device {
1270            DeviceSelection::Cpu => DType::F32,
1271            DeviceSelection::Cuda(_) => DType::BF16,
1272            DeviceSelection::Metal(_) => DType::F32,
1273            DeviceSelection::Auto => DType::BF16,
1274        },
1275        ModelType::Qwen3Tts => match device {
1276            DeviceSelection::Cpu => DType::F32,
1277            DeviceSelection::Cuda(_) => DType::BF16,
1278            DeviceSelection::Metal(_) => DType::BF16,
1279            DeviceSelection::Auto => DType::BF16,
1280        },
1281        ModelType::VibeVoice | ModelType::VibeVoiceRealtime => match device {
1282            DeviceSelection::Cpu => DType::F32,
1283            DeviceSelection::Cuda(_) => DType::BF16,
1284            DeviceSelection::Metal(_) => DType::F32,
1285            DeviceSelection::Auto => DType::BF16,
1286        },
1287        ModelType::Voxtral => match device {
1288            DeviceSelection::Cpu => DType::F32,
1289            DeviceSelection::Cuda(_) => DType::BF16,
1290            DeviceSelection::Metal(_) => DType::F32,
1291            DeviceSelection::Auto => DType::BF16,
1292        },
1293    }
1294}
1295
1296// ---------------------------------------------------------------------------
1297// TtsConfig — main configuration + builder
1298// ---------------------------------------------------------------------------
1299
/// Top-level configuration for loading a TTS model.
///
/// # Providing model files
///
/// There are four ways to tell any-tts where to find the model files,
/// listed from highest to lowest priority:
///
/// 1. **Individual file paths** (for custom download managers):
///    ```rust
///    # use any_tts::{TtsConfig, ModelType};
///    let config = TtsConfig::new(ModelType::Qwen3Tts)
///        .with_config_file("/cache/sha256-abc/config.json")
///        .with_tokenizer_file("/cache/sha256-def/tokenizer.json")
///        .with_weight_file("/cache/sha256-012/model.safetensors");
///    ```
///
/// 2. **Named in-memory assets** (for object stores and byte-first runtimes):
///    ```rust
///    # use any_tts::{ModelAssetBundle, ModelType, TtsConfig};
///    let bundle = ModelAssetBundle::new()
///        .with_bytes("config.json", vec![])
///        .with_bytes("tokenizer.json", vec![])
///        .with_bytes("model.safetensors", vec![]);
///    let config = TtsConfig::new(ModelType::Qwen3Tts)
///        .with_asset_bundle(bundle);
///    ```
///
/// 3. **Directory path** (all files in one folder):
///    ```rust
///    # use any_tts::{TtsConfig, ModelType};
///    let config = TtsConfig::new(ModelType::Qwen3Tts)
///        .with_model_path("/models/qwen3-tts");
///    ```
///
/// 4. **HuggingFace Hub download** (automatic fallback; the Hub download
///    helpers in this module are compiled only with the `download` cargo
///    feature):
///    ```rust
///    # use any_tts::{TtsConfig, ModelType};
///    let config = TtsConfig::new(ModelType::Qwen3Tts); // downloads automatically
///    ```
///
/// These can be mixed: set some files explicitly and let the rest be
/// auto-discovered or downloaded.
#[derive(Debug, Clone)]
pub struct TtsConfig {
    /// Which model backend to use.
    pub model_type: ModelType,

    /// Path to a local directory containing model weights and config.
    /// Files inside are auto-discovered by well-known filenames.
    ///
    /// Some backends may also accept a local model directory instead of a
    /// HuggingFace model ID.
    pub model_path: Option<String>,

    /// HuggingFace model ID (e.g. `"Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice"`).
    /// Used as download fallback when files are not found locally.
    pub hf_model_id: Option<String>,

    /// Override for an external runtime command when a backend needs one.
    pub runtime_command: Option<String>,

    /// Override for an external runtime endpoint when a backend needs one.
    pub runtime_endpoint: Option<String>,

    /// Bearer token used for external HTTP runtimes.
    pub bearer_token: Option<String>,

    /// Device selection strategy. [`TtsConfig::new`] starts with
    /// `DeviceSelection::Auto`.
    pub device: DeviceSelection,

    /// Data type for model weights.
    ///
    /// [`TtsConfig::new`] sets this to [`DType::default()`] (BF16)
    /// unconditionally. Use [`TtsConfig::with_preferred_runtime`] to pick a
    /// dtype suited to the detected device, or [`TtsConfig::with_dtype`] to
    /// choose one explicitly.
    pub dtype: DType,

    /// Individually specified model files, set through the `with_*_file` and
    /// `with_*_bytes` builders.
    pub files: ModelFiles,

    /// Named byte assets that can be auto-discovered like a model directory.
    pub asset_bundle: ModelAssetBundle,
}
1379
1380impl TtsConfig {
1381    /// Create a new configuration for the specified model type.
1382    pub fn new(model_type: ModelType) -> Self {
1383        Self {
1384            model_type,
1385            model_path: None,
1386            hf_model_id: None,
1387            runtime_command: None,
1388            runtime_endpoint: None,
1389            bearer_token: None,
1390            device: DeviceSelection::Auto,
1391            dtype: DType::default(),
1392            files: ModelFiles::default(),
1393            asset_bundle: ModelAssetBundle::default(),
1394        }
1395    }
1396
1397    // ── Directory / HF shortcuts ──────────────────────────────────────
1398
1399    /// Set the local directory containing all model files.
1400    ///
1401    /// The directory will be scanned for well-known filenames
1402    /// (`config.json`, `tokenizer.json`, `model.safetensors`, …).
1403    pub fn with_model_path(mut self, path: impl Into<String>) -> Self {
1404        self.model_path = Some(path.into());
1405        self
1406    }
1407
1408    /// Add a complete in-memory asset bundle using model-relative paths.
1409    pub fn with_asset_bundle(mut self, bundle: ModelAssetBundle) -> Self {
1410        self.asset_bundle = bundle;
1411        self
1412    }
1413
1414    /// Add a single in-memory asset using a model-relative path such as
1415    /// `config.json` or `audio_tokenizer/model.safetensors`.
1416    pub fn with_asset_bytes(
1417        mut self,
1418        relative_path: impl Into<String>,
1419        bytes: impl Into<Vec<u8>>,
1420    ) -> Self {
1421        self.asset_bundle.insert_bytes(relative_path, bytes);
1422        self
1423    }
1424
1425    /// Set the HuggingFace model ID for automatic download.
1426    ///
1427    /// Only used as a fallback when files cannot be found locally.
1428    pub fn with_hf_model_id(mut self, id: impl Into<String>) -> Self {
1429        self.hf_model_id = Some(id.into());
1430        self
1431    }
1432
1433    /// Override the executable used by runtime-adapter backends.
1434    pub fn with_runtime_command(mut self, command: impl Into<String>) -> Self {
1435        self.runtime_command = Some(command.into());
1436        self
1437    }
1438
1439    /// Override the HTTP endpoint used by runtime-adapter backends.
1440    pub fn with_runtime_endpoint(mut self, endpoint: impl Into<String>) -> Self {
1441        self.runtime_endpoint = Some(endpoint.into());
1442        self
1443    }
1444
1445    /// Set the bearer token used by HTTP runtime adapters.
1446    pub fn with_bearer_token(mut self, token: impl Into<String>) -> Self {
1447        self.bearer_token = Some(token.into());
1448        self
1449    }
1450
1451    // ── Device / dtype ────────────────────────────────────────────────
1452
1453    /// Set the device selection strategy.
1454    pub fn with_device(mut self, device: DeviceSelection) -> Self {
1455        self.device = device;
1456        self
1457    }
1458
1459    /// Set the data type for model weights.
1460    pub fn with_dtype(mut self, dtype: DType) -> Self {
1461        self.dtype = dtype;
1462        self
1463    }
1464
1465    /// Apply the fastest safe runtime choice for this model on the current machine.
1466    ///
1467    /// This resolves the current machine's preferred backend and dtype now,
1468    /// then stores the concrete selection in the config.
1469    pub fn with_preferred_runtime(mut self) -> Self {
1470        let runtime = preferred_runtime_choice(self.model_type);
1471        self.device = runtime.device;
1472        self.dtype = runtime.dtype;
1473        self
1474    }
1475
1476    // ── Individual file builders ──────────────────────────────────────
1477
1478    /// Set the path to **`config.json`**.
1479    ///
1480    /// This JSON file describes the model architecture: hidden size,
1481    /// number of layers, vocabulary size, attention head counts, etc.
1482    /// It follows the standard HuggingFace config format.
1483    pub fn with_config_file(mut self, path: impl Into<PathBuf>) -> Self {
1484        self.files.config = Some(ModelAsset::from_path(path.into()));
1485        self
1486    }
1487
1488    /// Set `config.json` from in-memory bytes.
1489    pub fn with_config_bytes(mut self, bytes: impl Into<Vec<u8>>) -> Self {
1490        self.files.config = Some(ModelAsset::from_bytes("config.json", bytes));
1491        self
1492    }
1493
1494    /// Set the path to **`tokenizer.json`**.
1495    ///
1496    /// A self-contained HuggingFace Tokenizers JSON file with the full
1497    /// BPE vocabulary, merge rules, and special-token definitions.
1498    /// Both model backends use this to convert input text to token IDs.
1499    pub fn with_tokenizer_file(mut self, path: impl Into<PathBuf>) -> Self {
1500        self.files.tokenizer = Some(ModelAsset::from_path(path.into()));
1501        self
1502    }
1503
1504    /// Set `tokenizer.json` from in-memory bytes.
1505    pub fn with_tokenizer_bytes(mut self, bytes: impl Into<Vec<u8>>) -> Self {
1506        self.files.tokenizer = Some(ModelAsset::from_bytes("tokenizer.json", bytes));
1507        self
1508    }
1509
1510    /// Append a single **model weight file** (`.safetensors`).
1511    ///
1512    /// Call this repeatedly when you need to provide several shards
1513    /// explicitly, or once for single-file models:
1514    ///
1515    /// ```rust
1516    /// # use any_tts::{TtsConfig, ModelType};
1517    /// let config = TtsConfig::new(ModelType::Qwen3Tts)
1518    ///     .with_weight_file("/cache/model.safetensors");
1519    /// ```
1520    pub fn with_weight_file(mut self, path: impl Into<PathBuf>) -> Self {
1521        self.files.weights.push(ModelAsset::from_path(path.into()));
1522        self
1523    }
1524
1525    /// Append a single in-memory weight file.
1526    pub fn with_weight_bytes(
1527        mut self,
1528        file_name: impl Into<String>,
1529        bytes: impl Into<Vec<u8>>,
1530    ) -> Self {
1531        self.files
1532            .weights
1533            .push(ModelAsset::from_bytes(file_name.into(), bytes));
1534        self
1535    }
1536
1537    /// Set **all model weight files** at once, replacing any previously added.
1538    pub fn with_weight_files(mut self, paths: Vec<PathBuf>) -> Self {
1539        self.files.weights = paths.into_iter().map(ModelAsset::from_path).collect();
1540        self
1541    }
1542
1543    /// Set a **voice asset directory** for backends that use preset voices.
1544    ///
1545    /// This is used by backends such as Kokoro (`voices/*.pt`) and
1546    /// Voxtral (`voice_embedding/*.pt`).
1547    pub fn with_voices_dir(mut self, path: impl Into<PathBuf>) -> Self {
1548        self.files.voices_dir = Some(ModelAssetDir::from_path(path.into()));
1549        self
1550    }
1551
1552    /// Add a single in-memory preset voice asset.
1553    pub fn with_voice_bytes(
1554        mut self,
1555        voice_name: impl Into<String>,
1556        bytes: impl Into<Vec<u8>>,
1557    ) -> Self {
1558        let voice_file = format!("{}.pt", voice_name.into());
1559        match self.files.voices_dir.take() {
1560            Some(ModelAssetDir::Bytes(mut entries)) => {
1561                entries.insert(voice_file, Arc::from(bytes.into()));
1562                self.files.voices_dir = Some(ModelAssetDir::from_bytes(entries));
1563            }
1564            Some(ModelAssetDir::Path(path)) => {
1565                self.files.voices_dir = Some(ModelAssetDir::Path(path));
1566                self.asset_bundle
1567                    .insert_bytes(format!("voices/{voice_file}"), bytes);
1568            }
1569            None => {
1570                let mut entries = BTreeMap::new();
1571                entries.insert(voice_file, Arc::from(bytes.into()));
1572                self.files.voices_dir = Some(ModelAssetDir::from_bytes(entries));
1573            }
1574        }
1575        self
1576    }
1577
1578    /// Append a single **speech tokenizer weight file** (Qwen3-TTS only).
1579    ///
1580    /// These weights belong to the separate speech tokenizer decoder
1581    /// model (`Qwen/Qwen3-TTS-Tokenizer-12Hz`) that converts discrete
1582    /// codec tokens into a continuous 24 kHz audio waveform.
1583    pub fn with_speech_tokenizer_weight_file(mut self, path: impl Into<PathBuf>) -> Self {
1584        self.files
1585            .speech_tokenizer_weights
1586            .push(ModelAsset::from_path(path.into()));
1587        self
1588    }
1589
1590    /// Append a single in-memory speech-tokenizer weight file.
1591    pub fn with_speech_tokenizer_weight_bytes(
1592        mut self,
1593        file_name: impl Into<String>,
1594        bytes: impl Into<Vec<u8>>,
1595    ) -> Self {
1596        self.files
1597            .speech_tokenizer_weights
1598            .push(ModelAsset::from_bytes(file_name.into(), bytes));
1599        self
1600    }
1601
1602    /// Set **all speech tokenizer weight files** at once (Qwen3-TTS only).
1603    pub fn with_speech_tokenizer_weight_files(mut self, paths: Vec<PathBuf>) -> Self {
1604        self.files.speech_tokenizer_weights =
1605            paths.into_iter().map(ModelAsset::from_path).collect();
1606        self
1607    }
1608
1609    /// Set the **speech tokenizer config** file (Qwen3-TTS only).
1610    ///
1611    /// JSON config for the speech tokenizer decoder model, including
1612    /// codebook dimensions, upsampling ratios, and activation parameters.
1613    pub fn with_speech_tokenizer_config_file(mut self, path: impl Into<PathBuf>) -> Self {
1614        self.files.speech_tokenizer_config = Some(ModelAsset::from_path(path.into()));
1615        self
1616    }
1617
1618    /// Set the speech-tokenizer config from in-memory bytes.
1619    pub fn with_speech_tokenizer_config_bytes(mut self, bytes: impl Into<Vec<u8>>) -> Self {
1620        self.files.speech_tokenizer_config = Some(ModelAsset::from_bytes(
1621            "speech_tokenizer/config.json",
1622            bytes,
1623        ));
1624        self
1625    }
1626
1627    /// Set the **generation config** file (optional).
1628    ///
1629    /// Standard HuggingFace `generation_config.json` with parameters
1630    /// like `max_new_tokens`, `top_p`, `temperature`, `do_sample`, etc.
1631    pub fn with_generation_config_file(mut self, path: impl Into<PathBuf>) -> Self {
1632        self.files.generation_config = Some(ModelAsset::from_path(path.into()));
1633        self
1634    }
1635
1636    /// Set `generation_config.json` from in-memory bytes.
1637    pub fn with_generation_config_bytes(mut self, bytes: impl Into<Vec<u8>>) -> Self {
1638        self.files.generation_config =
1639            Some(ModelAsset::from_bytes("generation_config.json", bytes));
1640        self
1641    }
1642
1643    /// Set the **preprocessor config** file (optional).
1644    ///
1645    /// This stores published preprocessing defaults such as audio
1646    /// normalization parameters and speech token compression ratios.
1647    pub fn with_preprocessor_config_file(mut self, path: impl Into<PathBuf>) -> Self {
1648        self.files.preprocessor_config = Some(ModelAsset::from_path(path.into()));
1649        self
1650    }
1651
1652    /// Set `preprocessor_config.json` from in-memory bytes.
1653    pub fn with_preprocessor_config_bytes(mut self, bytes: impl Into<Vec<u8>>) -> Self {
1654        self.files.preprocessor_config =
1655            Some(ModelAsset::from_bytes("preprocessor_config.json", bytes));
1656        self
1657    }
1658
1659    // ── Resolution ────────────────────────────────────────────────────
1660
1661    /// Resolve all model files using the four-tier strategy:
1662    ///
1663    /// 1. Explicit paths already set on `self.files`
1664    /// 2. Auto-discovery from `self.asset_bundle`
1665    /// 3. Auto-discovery from `self.model_path` directory
1666    /// 4. HuggingFace Hub download (if `download` feature enabled)
1667    ///
1668    /// Returns a fully populated [`ModelFiles`] or an error listing
1669    /// which files are missing.
1670    pub fn resolve_files(&self) -> Result<ModelFiles, TtsError> {
1671        let mut files = self.files.clone();
1672
1673        if !self.asset_bundle.is_empty() {
1674            files.fill_from_asset_bundle(&self.asset_bundle);
1675        }
1676
1677        // Tier 3: fill from model_path directory
1678        if let Some(ref dir) = self.model_path {
1679            files.fill_from_directory(Path::new(dir));
1680        }
1681
1682        // Tier 4: HuggingFace Hub download fallback
1683        #[cfg(feature = "download")]
1684        {
1685            if !files.missing_files(self.model_type).is_empty() {
1686                let hf_id = self.effective_hf_model_id();
1687                info!("Downloading missing files from HuggingFace: {}", hf_id);
1688                files.fill_from_hf(hf_id, self.model_type, self.bearer_token.as_deref())?;
1689            }
1690        }
1691
1692        // Validate completeness
1693        files.validate(self.model_type)?;
1694
1695        Ok(files)
1696    }
1697
1698    /// Get the default HuggingFace model ID for this model type.
1699    pub fn default_hf_model_id(&self) -> &str {
1700        match self.model_type {
1701            ModelType::Kokoro => "hexgrad/Kokoro-82M",
1702            ModelType::OmniVoice => "k2-fsa/OmniVoice",
1703            ModelType::Qwen3Tts => "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice",
1704            ModelType::VibeVoice => "microsoft/VibeVoice-1.5B",
1705            ModelType::VibeVoiceRealtime => "microsoft/VibeVoice-Realtime-0.5B",
1706            ModelType::Voxtral => "mistralai/Voxtral-4B-TTS-2603",
1707        }
1708    }
1709
1710    /// Resolve the effective HuggingFace model ID.
1711    pub fn effective_hf_model_id(&self) -> &str {
1712        self.hf_model_id
1713            .as_deref()
1714            .unwrap_or_else(|| self.default_hf_model_id())
1715    }
1716
1717    /// Resolve the model reference to forward to an external runtime.
1718    pub fn effective_model_ref(&self) -> &str {
1719        self.model_path
1720            .as_deref()
1721            .unwrap_or_else(|| self.effective_hf_model_id())
1722    }
1723
1724    /// Get the default runtime command for the configured model, if any.
1725    pub fn default_runtime_command(&self) -> Option<&str> {
1726        match self.model_type {
1727            ModelType::Voxtral => Some("python3"),
1728            ModelType::Kokoro
1729            | ModelType::OmniVoice
1730            | ModelType::Qwen3Tts
1731            | ModelType::VibeVoice
1732            | ModelType::VibeVoiceRealtime => None,
1733        }
1734    }
1735
1736    /// Resolve the effective runtime command for adapter backends.
1737    pub fn effective_runtime_command(&self) -> Option<&str> {
1738        self.runtime_command
1739            .as_deref()
1740            .or_else(|| self.default_runtime_command())
1741    }
1742
1743    /// Get the default runtime endpoint for the configured model, if any.
1744    pub fn default_runtime_endpoint(&self) -> Option<&str> {
1745        match self.model_type {
1746            ModelType::Kokoro
1747            | ModelType::OmniVoice
1748            | ModelType::Qwen3Tts
1749            | ModelType::VibeVoice
1750            | ModelType::VibeVoiceRealtime
1751            | ModelType::Voxtral => None,
1752        }
1753    }
1754
1755    /// Resolve the effective runtime endpoint for adapter backends.
1756    pub fn effective_runtime_endpoint(&self) -> Option<&str> {
1757        self.runtime_endpoint
1758            .as_deref()
1759            .or_else(|| self.default_runtime_endpoint())
1760    }
1761
1762    /// Resolve the bearer token used by external HTTP runtimes.
1763    pub fn effective_bearer_token(&self) -> &str {
1764        self.bearer_token.as_deref().unwrap_or("EMPTY")
1765    }
1766}
1767
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dtype_labels_are_stable() {
        assert_eq!(DType::F32.label(), "f32");
        assert_eq!(DType::F16.label(), "f16");
        assert_eq!(DType::BF16.label(), "bf16");
    }

    #[test]
    fn test_kokoro_metal_prefers_f32() {
        assert_eq!(
            preferred_dtype_for(ModelType::Kokoro, DeviceSelection::Metal(0)),
            DType::F32
        );
    }

    #[test]
    fn test_qwen3_metal_prefers_bf16() {
        assert_eq!(
            preferred_dtype_for(ModelType::Qwen3Tts, DeviceSelection::Metal(0)),
            DType::BF16
        );
    }

    #[test]
    fn test_omnivoice_metal_prefers_f32() {
        let choice = RuntimeChoice {
            device: DeviceSelection::Metal(0),
            dtype: preferred_dtype_for(ModelType::OmniVoice, DeviceSelection::Metal(0)),
        };
        assert_eq!(choice.label(), "metal:0 (f32)");
    }

    #[test]
    fn test_with_preferred_runtime_applies_choice() {
        let expected = preferred_runtime_choice(ModelType::VibeVoice);
        let config = TtsConfig::new(ModelType::VibeVoice).with_preferred_runtime();
        assert_eq!(config.device, expected.device);
        assert_eq!(config.dtype, expected.dtype);
    }

    #[test]
    fn test_resolve_files_from_in_memory_omnivoice_assets() {
        let bundle = ModelAssetBundle::new()
            .with_bytes("config.json", vec![1])
            .with_bytes("tokenizer.json", vec![2])
            .with_bytes("model.safetensors", vec![3])
            .with_bytes("audio_tokenizer/config.json", vec![4])
            .with_bytes("audio_tokenizer/model.safetensors", vec![5]);

        let files = TtsConfig::new(ModelType::OmniVoice)
            .with_asset_bundle(bundle)
            .resolve_files()
            .unwrap();

        assert!(matches!(files.config, Some(ModelAsset::Bytes { .. })));
        assert!(matches!(files.tokenizer, Some(ModelAsset::Bytes { .. })));
        assert_eq!(files.weights.len(), 1);
        assert_eq!(files.speech_tokenizer_weights.len(), 1);
    }

    #[test]
    fn test_with_voice_bytes_creates_in_memory_voice_dir() {
        let config = TtsConfig::new(ModelType::Kokoro).with_voice_bytes("af_heart", vec![1, 2]);
        let voices_dir = config.files.voices_dir.as_ref().unwrap();
        assert_eq!(voices_dir.file_names().unwrap(), vec!["af_heart.pt"]);
    }

    /// A second in-memory voice must be added to the existing in-memory
    /// directory rather than replacing it (the `Bytes` branch).
    #[test]
    fn test_with_voice_bytes_accumulates_in_memory_voices() {
        let config = TtsConfig::new(ModelType::Kokoro)
            .with_voice_bytes("af_heart", vec![1])
            .with_voice_bytes("am_adam", vec![2]);
        let voices_dir = config.files.voices_dir.as_ref().unwrap();
        assert_eq!(
            voices_dir.file_names().unwrap(),
            vec!["af_heart.pt", "am_adam.pt"]
        );
    }

    /// With an on-disk voices directory configured, the directory is kept and
    /// the bytes land in the asset bundle instead (the `Path` branch).
    #[test]
    fn test_with_voice_bytes_keeps_on_disk_voice_dir() {
        let config = TtsConfig::new(ModelType::Kokoro)
            .with_voices_dir("voices")
            .with_voice_bytes("af_heart", vec![1, 2]);
        assert!(matches!(
            config.files.voices_dir,
            Some(ModelAssetDir::Path(_))
        ));
        assert!(!config.asset_bundle.is_empty());
    }

    #[test]
    fn test_in_memory_config_assets_use_canonical_names() {
        let config = TtsConfig::new(ModelType::Qwen3Tts)
            .with_generation_config_bytes(vec![1])
            .with_speech_tokenizer_config_bytes(vec![2])
            .with_preprocessor_config_bytes(vec![3]);
        assert!(matches!(
            &config.files.generation_config,
            Some(ModelAsset::Bytes { name, .. }) if name == "generation_config.json"
        ));
        assert!(matches!(
            &config.files.speech_tokenizer_config,
            Some(ModelAsset::Bytes { name, .. }) if name == "speech_tokenizer/config.json"
        ));
        assert!(matches!(
            &config.files.preprocessor_config,
            Some(ModelAsset::Bytes { name, .. }) if name == "preprocessor_config.json"
        ));
    }

    #[test]
    fn test_default_hf_model_ids() {
        assert_eq!(
            TtsConfig::new(ModelType::Kokoro).default_hf_model_id(),
            "hexgrad/Kokoro-82M"
        );
        assert_eq!(
            TtsConfig::new(ModelType::Voxtral).default_hf_model_id(),
            "mistralai/Voxtral-4B-TTS-2603"
        );
    }

    #[test]
    fn test_default_runtime_command_is_voxtral_only() {
        assert_eq!(
            TtsConfig::new(ModelType::Voxtral).default_runtime_command(),
            Some("python3")
        );
        assert_eq!(
            TtsConfig::new(ModelType::Kokoro).default_runtime_command(),
            None
        );
    }

    #[test]
    fn test_model_asset_manifest_is_available() {
        let requirements = ModelType::Voxtral.asset_requirements();
        assert!(!requirements.is_empty());
        assert!(requirements
            .iter()
            .any(|entry| entry.pattern == "voice_embedding/*.pt"));
    }
}