mistralrs_core/speech_models/
mod.rs1mod bs1770;
2mod dia;
3pub mod utils;
4
5use std::{str::FromStr, sync::Arc};
6
7pub use dia::{DiaConfig, DiaPipeline};
8use serde::{Deserialize, Serialize};
9
10#[derive(Clone, Copy, Debug, Deserialize, Serialize, PartialEq)]
11pub enum SpeechLoaderType {
12 #[serde(rename = "dia")]
13 Dia,
14}
15
16impl FromStr for SpeechLoaderType {
17 type Err = String;
18 fn from_str(s: &str) -> Result<Self, Self::Err> {
19 match s {
20 "dia" => Ok(Self::Dia),
21 a => Err(format!(
22 "Unknown architecture `{a}`. Possible architectures: `dia`."
23 )),
24 }
25 }
26}
27
28impl SpeechLoaderType {
29 pub fn auto_detect_from_config(config: &str) -> Option<Self> {
32 if serde_json::from_str::<DiaConfig>(config).is_ok() {
33 return Some(Self::Dia);
34 }
35 None
36 }
37}
38
/// Per-architecture sampling settings for speech generation.
///
/// Each variant carries the knobs for one architecture; currently only Dia
/// is present.
#[derive(Debug, Clone, Copy)]
pub enum SpeechGenerationConfig {
    /// Settings for the Dia architecture.
    Dia {
        /// Optional cap on generated tokens; `None` leaves it unset.
        /// NOTE(review): how `None` is interpreted is decided by the consumer — confirm.
        max_tokens: Option<usize>,
        /// Presumably the classifier-free guidance scale — TODO confirm against the pipeline.
        cfg_scale: f32,
        /// Sampling temperature.
        temperature: f32,
        /// Presumably the nucleus (top-p) sampling threshold — TODO confirm.
        top_p: f32,
        /// Optional top-k sampling cutoff; `None` leaves it unset.
        top_k: Option<usize>,
    },
}
49
50impl SpeechGenerationConfig {
51 pub fn default(ty: SpeechLoaderType) -> Self {
52 match ty {
53 SpeechLoaderType::Dia => Self::Dia {
54 max_tokens: None,
55 cfg_scale: 3.,
56 temperature: 1.3,
57 top_p: 0.95,
58 top_k: Some(35),
59 },
60 }
61 }
62}
63
/// Audio produced by a speech generation request.
///
/// The sample buffer sits behind an [`Arc`], so cloning this struct shares
/// the buffer instead of copying the samples.
#[derive(Debug, Clone)]
pub struct SpeechGenerationOutput {
    /// PCM samples as `f32` values.
    /// NOTE(review): value range and channel interleaving are not visible here — confirm.
    pub pcm: Arc<Vec<f32>>,
    /// Sample rate of `pcm` (presumably in Hz — TODO confirm).
    pub rate: usize,
    /// Number of audio channels in `pcm`.
    pub channels: usize,
}