Skip to main content

mistralrs_core/speech_models/
mod.rs

1mod bs1770;
2mod dia;
3pub mod utils;
4
5use std::{str::FromStr, sync::Arc};
6
7pub use dia::{DiaConfig, DiaPipeline};
8use serde::{Deserialize, Serialize};
9
10#[derive(Clone, Copy, Debug, Deserialize, Serialize, PartialEq)]
11pub enum SpeechLoaderType {
12    #[serde(rename = "dia")]
13    Dia,
14}
15
16impl FromStr for SpeechLoaderType {
17    type Err = String;
18    fn from_str(s: &str) -> Result<Self, Self::Err> {
19        match s {
20            "dia" => Ok(Self::Dia),
21            a => Err(format!(
22                "Unknown architecture `{a}`. Possible architectures: `dia`."
23            )),
24        }
25    }
26}
27
28impl SpeechLoaderType {
29    /// Auto-detect speech loader type from a config.json string.
30    /// Extend this when adding new speech pipelines.
31    pub fn auto_detect_from_config(config: &str) -> Option<Self> {
32        if serde_json::from_str::<DiaConfig>(config).is_ok() {
33            return Some(Self::Dia);
34        }
35        None
36    }
37}
38
/// Per-pipeline generation settings for speech synthesis.
///
/// There is one variant per supported speech pipeline; see
/// `SpeechGenerationConfig::default` for the preset values of each.
#[derive(Debug, Clone, Copy)]
pub enum SpeechGenerationConfig {
    /// Settings for the Dia pipeline.
    Dia {
        /// Optional cap on the number of generated tokens.
        /// NOTE(review): the behavior when `None` is decided by the pipeline,
        /// not here — confirm against the Dia implementation.
        max_tokens: Option<usize>,
        /// Guidance scale.
        /// NOTE(review): presumably the classifier-free guidance strength — confirm.
        cfg_scale: f32,
        /// Sampling temperature.
        temperature: f32,
        /// Top-p sampling threshold.
        /// NOTE(review): presumably a nucleus-sampling probability mass — confirm.
        top_p: f32,
        /// Optional top-k sampling limit.
        /// NOTE(review): presumably `None` disables top-k filtering — confirm.
        top_k: Option<usize>,
    },
}
49
50impl SpeechGenerationConfig {
51    pub fn default(ty: SpeechLoaderType) -> Self {
52        match ty {
53            SpeechLoaderType::Dia => Self::Dia {
54                max_tokens: None,
55                cfg_scale: 3.,
56                temperature: 1.3,
57                top_p: 0.95,
58                top_k: Some(35),
59            },
60        }
61    }
62}
63
/// Audio produced by a speech generation request.
///
/// The sample buffer is held behind an [`Arc`], so cloning this struct shares
/// the samples rather than copying them.
#[derive(Debug, Clone)]
pub struct SpeechGenerationOutput {
    /// Generated audio samples as 32-bit floats.
    /// NOTE(review): sample range and channel interleaving are not visible
    /// here — confirm against the producing pipeline.
    pub pcm: Arc<Vec<f32>>,
    /// Sample rate of `pcm`.
    /// NOTE(review): presumably in Hz — confirm.
    pub rate: usize,
    /// Number of audio channels in `pcm`.
    pub channels: usize,
}