Skip to main content

rlx_voxtral/
config.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Voxtral configuration — HuggingFace `config.json` (`mistralai/Voxtral-Mini-3B-2507`).
17
18use anyhow::{Context, Result, ensure};
19use rlx_llama32::Llama32Config;
20use serde::Deserialize;
21use std::path::Path;
22
/// Whisper-style audio encoder section (`audio_config` in HF JSON).
///
/// Deserialized with serde. Several fields are read from a differently named
/// JSON key (see the `rename`/`alias` attributes), and the alias lets the
/// Rust-side field name be used as the key as well.
#[derive(Debug, Clone, Deserialize)]
pub struct VoxtralAudioConfig {
    /// Value of the `num_mel_bins` key.
    /// NOTE(review): presumably the mel-spectrogram feature dimension — confirm against the encoder.
    pub num_mel_bins: usize,
    /// Value of the `max_source_positions` key.
    /// NOTE(review): presumably the maximum encoder sequence length — confirm against the encoder.
    pub max_source_positions: usize,
    /// Encoder width. Read from `hidden_size`; `d_model` is accepted as an alias.
    #[serde(rename = "hidden_size", alias = "d_model")]
    pub d_model: usize,
    /// Attention head count. Read from `num_attention_heads`;
    /// `encoder_attention_heads` is accepted as an alias.
    #[serde(rename = "num_attention_heads", alias = "encoder_attention_heads")]
    pub encoder_attention_heads: usize,
    /// Encoder layer count. Read from `num_hidden_layers`;
    /// `encoder_layers` is accepted as an alias.
    #[serde(rename = "num_hidden_layers", alias = "encoder_layers")]
    pub encoder_layers: usize,
    /// Value of the `intermediate_size` key. `VoxtralConfig::validate`
    /// requires it to equal `4 * d_model`.
    pub intermediate_size: usize,
    /// Optional flag; `false` when the key is absent from the JSON.
    #[serde(default)]
    pub scale_embedding: bool,
}
38
39impl VoxtralAudioConfig {
40    pub fn head_dim(&self) -> usize {
41        self.d_model / self.encoder_attention_heads
42    }
43
44    /// Sequence length after two stride convolutions (same as Whisper).
45    pub fn encoder_seq_len(&self, mel_frames: usize) -> usize {
46        let after_conv1 = mel_frames;
47        let pad = 1usize;
48        let k = 3usize;
49        let stride2 = 2usize;
50        (after_conv1 + 2 * pad - k) / stride2 + 1
51    }
52
53    /// Audio frames after the 4× projector grouping.
54    pub fn audio_token_count(&self, mel_frames: usize) -> usize {
55        self.encoder_seq_len(mel_frames) / 4
56    }
57
58    pub fn tiny_synthetic() -> Self {
59        Self {
60            num_mel_bins: 4,
61            max_source_positions: 16,
62            d_model: 8,
63            encoder_attention_heads: 2,
64            encoder_layers: 1,
65            intermediate_size: 32,
66            scale_embedding: false,
67        }
68    }
69
70    pub fn mini_3b() -> Self {
71        Self {
72            num_mel_bins: 128,
73            max_source_positions: 1500,
74            d_model: 1280,
75            encoder_attention_heads: 20,
76            encoder_layers: 32,
77            intermediate_size: 5120,
78            scale_embedding: false,
79        }
80    }
81}
82
/// Top-level Voxtral checkpoint config.
///
/// Deserialized with serde from the checkpoint's `config.json`. The two
/// nested sections and `vocab_size` are required keys; the remaining fields
/// fall back to defaults when absent.
#[derive(Debug, Clone, Deserialize)]
pub struct VoxtralConfig {
    /// Audio encoder section (`audio_config` key).
    pub audio_config: VoxtralAudioConfig,
    /// Text model section (`text_config` key), deserialized as a `Llama32Config`.
    pub text_config: Llama32Config,
    /// Value of the `audio_token_id` key; `24` when the key is absent.
    /// NOTE(review): presumably the placeholder token id marking audio positions — confirm against the caller.
    #[serde(default = "default_audio_token_id")]
    pub audio_token_id: u32,
    /// Value of the `projector_hidden_act` key; `"gelu"` when the key is absent.
    #[serde(default = "default_projector_act")]
    pub projector_hidden_act: String,
    /// Value of the top-level `vocab_size` key (stored separately from
    /// `text_config.vocab_size`).
    pub vocab_size: usize,
}
94
/// Serde fallback for `VoxtralConfig::audio_token_id` when the JSON omits the key.
fn default_audio_token_id() -> u32 {
    const FALLBACK_AUDIO_TOKEN_ID: u32 = 24;
    FALLBACK_AUDIO_TOKEN_ID
}
98
/// Serde fallback for `VoxtralConfig::projector_hidden_act` when the JSON omits the key.
fn default_projector_act() -> String {
    String::from("gelu")
}
102
103impl VoxtralConfig {
104    pub fn from_file(path: &Path) -> Result<Self> {
105        let data = std::fs::read_to_string(path)?;
106        serde_json::from_str(&data).with_context(|| format!("parse Voxtral config {path:?}"))
107    }
108
109    pub fn llama_config(&self) -> &Llama32Config {
110        &self.text_config
111    }
112
113    pub fn validate(&self) -> Result<()> {
114        ensure!(
115            self.text_config.hidden_size > 0,
116            "text_config.hidden_size must be > 0"
117        );
118        ensure!(
119            self.audio_config.intermediate_size == self.audio_config.d_model * 4,
120            "audio_config.intermediate_size should be 4× d_model for the projector reshape"
121        );
122        Ok(())
123    }
124
125    pub fn tiny_synthetic() -> Self {
126        Self {
127            audio_config: VoxtralAudioConfig::tiny_synthetic(),
128            text_config: Llama32Config {
129                vocab_size: 32,
130                hidden_size: 16,
131                intermediate_size: 32,
132                num_hidden_layers: 1,
133                num_attention_heads: 4,
134                num_key_value_heads: 2,
135                max_position_embeddings: 16,
136                rms_norm_eps: 1e-5,
137                rope_theta: 100_000_000.0,
138                hidden_act: "silu".into(),
139                tie_word_embeddings: true,
140                attention_bias: false,
141                head_dim: Some(4),
142                rope_scaling: None,
143            },
144            audio_token_id: 24,
145            projector_hidden_act: "gelu".into(),
146            vocab_size: 32,
147        }
148    }
149
150    pub fn mini_3b() -> Self {
151        Self {
152            audio_config: VoxtralAudioConfig::mini_3b(),
153            text_config: Llama32Config {
154                vocab_size: 131_072,
155                hidden_size: 3072,
156                intermediate_size: 8192,
157                num_hidden_layers: 30,
158                num_attention_heads: 32,
159                num_key_value_heads: 8,
160                max_position_embeddings: 131_072,
161                rms_norm_eps: 1e-5,
162                rope_theta: 100_000_000.0,
163                hidden_act: "silu".into(),
164                tie_word_embeddings: true,
165                attention_bias: false,
166                head_dim: Some(128),
167                rope_scaling: None,
168            },
169            audio_token_id: 24,
170            projector_hidden_act: "gelu".into(),
171            vocab_size: 131_072,
172        }
173    }
174}