1use serde::Deserialize;
19use std::path::Path;
20
21#[derive(Debug, Clone, Deserialize)]
23pub struct WhisperConfig {
24 pub num_mel_bins: usize,
25 pub max_source_positions: usize,
26 pub d_model: usize,
27 pub encoder_attention_heads: usize,
28 pub encoder_layers: usize,
29 pub vocab_size: usize,
30 pub max_target_positions: usize,
31 pub decoder_attention_heads: usize,
32 pub decoder_layers: usize,
33 #[serde(default)]
34 pub suppress_tokens: Vec<u32>,
35}
36
37impl WhisperConfig {
38 pub fn from_file(path: &Path) -> anyhow::Result<Self> {
39 let data = std::fs::read_to_string(path)?;
40 Ok(serde_json::from_str(&data)?)
41 }
42
43 pub fn head_dim(&self) -> usize {
44 self.d_model / self.encoder_attention_heads
45 }
46
47 pub fn decoder_head_dim(&self) -> usize {
48 self.d_model / self.decoder_attention_heads
49 }
50
51 pub fn encoder_seq_len(&self, mel_frames: usize) -> usize {
53 let after_conv1 = mel_frames;
54 let pad = 1usize;
55 let k = 3usize;
56 let stride2 = 2usize;
57 (after_conv1 + 2 * pad - k) / stride2 + 1
58 }
59
60 pub fn tiny_synthetic() -> Self {
62 Self {
63 num_mel_bins: 4,
64 max_source_positions: 16,
65 d_model: 8,
66 encoder_attention_heads: 2,
67 encoder_layers: 1,
68 vocab_size: 32,
69 max_target_positions: 16,
70 decoder_attention_heads: 2,
71 decoder_layers: 1,
72 suppress_tokens: vec![],
73 }
74 }
75
76 pub fn tiny() -> Self {
78 Self {
79 num_mel_bins: 80,
80 max_source_positions: 1500,
81 d_model: 384,
82 encoder_attention_heads: 6,
83 encoder_layers: 4,
84 vocab_size: 51865,
85 max_target_positions: 448,
86 decoder_attention_heads: 6,
87 decoder_layers: 4,
88 suppress_tokens: vec![],
89 }
90 }
91}