Skip to main content

rlx_whisper/
config.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Whisper configuration — HuggingFace `config.json` / OpenAI dims.
17
use anyhow::Context;
use serde::Deserialize;
use std::path::Path;
20
21/// Whisper model dimensions (HF field names; see OpenAI `model.py` comments in candle).
22#[derive(Debug, Clone, Deserialize)]
23pub struct WhisperConfig {
24    pub num_mel_bins: usize,
25    pub max_source_positions: usize,
26    pub d_model: usize,
27    pub encoder_attention_heads: usize,
28    pub encoder_layers: usize,
29    pub vocab_size: usize,
30    pub max_target_positions: usize,
31    pub decoder_attention_heads: usize,
32    pub decoder_layers: usize,
33    #[serde(default)]
34    pub suppress_tokens: Vec<u32>,
35    /// Suppressed on the first generated token only (HF `begin_suppress_tokens`).
36    #[serde(default)]
37    pub begin_suppress_tokens: Vec<u32>,
38}
39
40impl WhisperConfig {
41    pub fn from_file(path: &Path) -> anyhow::Result<Self> {
42        let data = std::fs::read_to_string(path)?;
43        Ok(serde_json::from_str(&data)?)
44    }
45
46    pub fn head_dim(&self) -> usize {
47        self.d_model / self.encoder_attention_heads
48    }
49
50    pub fn decoder_head_dim(&self) -> usize {
51        self.d_model / self.decoder_attention_heads
52    }
53
54    /// Encoder sequence length after the two conv layers (stride-2 on the second).
55    pub fn encoder_seq_len(&self, mel_frames: usize) -> usize {
56        let after_conv1 = mel_frames;
57        let pad = 1usize;
58        let k = 3usize;
59        let stride2 = 2usize;
60        (after_conv1 + 2 * pad - k) / stride2 + 1
61    }
62
63    /// Tiny config for compile tests (no real weights required).
64    pub fn tiny_synthetic() -> Self {
65        Self {
66            num_mel_bins: 4,
67            max_source_positions: 16,
68            d_model: 8,
69            encoder_attention_heads: 2,
70            encoder_layers: 1,
71            vocab_size: 32,
72            max_target_positions: 16,
73            decoder_attention_heads: 2,
74            decoder_layers: 1,
75            suppress_tokens: vec![],
76            begin_suppress_tokens: vec![],
77        }
78    }
79
80    /// `openai/whisper-tiny` dimensions.
81    pub fn tiny() -> Self {
82        Self {
83            num_mel_bins: 80,
84            max_source_positions: 1500,
85            d_model: 384,
86            encoder_attention_heads: 6,
87            encoder_layers: 4,
88            vocab_size: 51865,
89            max_target_positions: 448,
90            decoder_attention_heads: 6,
91            decoder_layers: 4,
92            suppress_tokens: vec![],
93            begin_suppress_tokens: vec![220, 50257],
94        }
95    }
96}