1use serde::Deserialize;
19use std::path::Path;
20
21#[derive(Debug, Clone, Deserialize)]
23pub struct WhisperConfig {
24 pub num_mel_bins: usize,
25 pub max_source_positions: usize,
26 pub d_model: usize,
27 pub encoder_attention_heads: usize,
28 pub encoder_layers: usize,
29 pub vocab_size: usize,
30 pub max_target_positions: usize,
31 pub decoder_attention_heads: usize,
32 pub decoder_layers: usize,
33 #[serde(default)]
34 pub suppress_tokens: Vec<u32>,
35 #[serde(default)]
37 pub begin_suppress_tokens: Vec<u32>,
38}
39
40impl WhisperConfig {
41 pub fn from_file(path: &Path) -> anyhow::Result<Self> {
42 let data = std::fs::read_to_string(path)?;
43 Ok(serde_json::from_str(&data)?)
44 }
45
46 pub fn head_dim(&self) -> usize {
47 self.d_model / self.encoder_attention_heads
48 }
49
50 pub fn decoder_head_dim(&self) -> usize {
51 self.d_model / self.decoder_attention_heads
52 }
53
54 pub fn encoder_seq_len(&self, mel_frames: usize) -> usize {
56 let after_conv1 = mel_frames;
57 let pad = 1usize;
58 let k = 3usize;
59 let stride2 = 2usize;
60 (after_conv1 + 2 * pad - k) / stride2 + 1
61 }
62
63 pub fn tiny_synthetic() -> Self {
65 Self {
66 num_mel_bins: 4,
67 max_source_positions: 16,
68 d_model: 8,
69 encoder_attention_heads: 2,
70 encoder_layers: 1,
71 vocab_size: 32,
72 max_target_positions: 16,
73 decoder_attention_heads: 2,
74 decoder_layers: 1,
75 suppress_tokens: vec![],
76 begin_suppress_tokens: vec![],
77 }
78 }
79
80 pub fn tiny() -> Self {
82 Self {
83 num_mel_bins: 80,
84 max_source_positions: 1500,
85 d_model: 384,
86 encoder_attention_heads: 6,
87 encoder_layers: 4,
88 vocab_size: 51865,
89 max_target_positions: 448,
90 decoder_attention_heads: 6,
91 decoder_layers: 4,
92 suppress_tokens: vec![],
93 begin_suppress_tokens: vec![220, 50257],
94 }
95 }
96}