Skip to main content

rlx_whisper/
config.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

//! Whisper model configuration: HuggingFace `config.json` loading plus built-in OpenAI dims.

18use serde::Deserialize;
19use std::path::Path;
20
21/// Whisper model dimensions (HF field names; see OpenAI `model.py` comments in candle).
22#[derive(Debug, Clone, Deserialize)]
23pub struct WhisperConfig {
24    pub num_mel_bins: usize,
25    pub max_source_positions: usize,
26    pub d_model: usize,
27    pub encoder_attention_heads: usize,
28    pub encoder_layers: usize,
29    pub vocab_size: usize,
30    pub max_target_positions: usize,
31    pub decoder_attention_heads: usize,
32    pub decoder_layers: usize,
33    #[serde(default)]
34    pub suppress_tokens: Vec<u32>,
35}
36
37impl WhisperConfig {
38    pub fn from_file(path: &Path) -> anyhow::Result<Self> {
39        let data = std::fs::read_to_string(path)?;
40        Ok(serde_json::from_str(&data)?)
41    }
42
43    pub fn head_dim(&self) -> usize {
44        self.d_model / self.encoder_attention_heads
45    }
46
47    pub fn decoder_head_dim(&self) -> usize {
48        self.d_model / self.decoder_attention_heads
49    }
50
51    /// Encoder sequence length after the two conv layers (stride-2 on the second).
52    pub fn encoder_seq_len(&self, mel_frames: usize) -> usize {
53        let after_conv1 = mel_frames;
54        let pad = 1usize;
55        let k = 3usize;
56        let stride2 = 2usize;
57        (after_conv1 + 2 * pad - k) / stride2 + 1
58    }
59
60    /// Tiny config for compile tests (no real weights required).
61    pub fn tiny_synthetic() -> Self {
62        Self {
63            num_mel_bins: 4,
64            max_source_positions: 16,
65            d_model: 8,
66            encoder_attention_heads: 2,
67            encoder_layers: 1,
68            vocab_size: 32,
69            max_target_positions: 16,
70            decoder_attention_heads: 2,
71            decoder_layers: 1,
72            suppress_tokens: vec![],
73        }
74    }
75
76    /// `openai/whisper-tiny` dimensions.
77    pub fn tiny() -> Self {
78        Self {
79            num_mel_bins: 80,
80            max_source_positions: 1500,
81            d_model: 384,
82            encoder_attention_heads: 6,
83            encoder_layers: 4,
84            vocab_size: 51865,
85            max_target_positions: 448,
86            decoder_attention_heads: 6,
87            decoder_layers: 4,
88            suppress_tokens: vec![],
89        }
90    }
91}