1use anyhow::{Context, Result, ensure};
19use rlx_llama32::Llama32Config;
20use serde::Deserialize;
21use std::path::Path;
22
23#[derive(Debug, Clone, Deserialize)]
25pub struct VoxtralAudioConfig {
26 pub num_mel_bins: usize,
27 pub max_source_positions: usize,
28 #[serde(rename = "hidden_size", alias = "d_model")]
29 pub d_model: usize,
30 #[serde(rename = "num_attention_heads", alias = "encoder_attention_heads")]
31 pub encoder_attention_heads: usize,
32 #[serde(rename = "num_hidden_layers", alias = "encoder_layers")]
33 pub encoder_layers: usize,
34 pub intermediate_size: usize,
35 #[serde(default)]
36 pub scale_embedding: bool,
37}
38
39impl VoxtralAudioConfig {
40 pub fn head_dim(&self) -> usize {
41 self.d_model / self.encoder_attention_heads
42 }
43
44 pub fn encoder_seq_len(&self, mel_frames: usize) -> usize {
46 let after_conv1 = mel_frames;
47 let pad = 1usize;
48 let k = 3usize;
49 let stride2 = 2usize;
50 (after_conv1 + 2 * pad - k) / stride2 + 1
51 }
52
53 pub fn audio_token_count(&self, mel_frames: usize) -> usize {
55 self.encoder_seq_len(mel_frames) / 4
56 }
57
58 pub fn tiny_synthetic() -> Self {
59 Self {
60 num_mel_bins: 4,
61 max_source_positions: 16,
62 d_model: 8,
63 encoder_attention_heads: 2,
64 encoder_layers: 1,
65 intermediate_size: 32,
66 scale_embedding: false,
67 }
68 }
69
70 pub fn mini_3b() -> Self {
71 Self {
72 num_mel_bins: 128,
73 max_source_positions: 1500,
74 d_model: 1280,
75 encoder_attention_heads: 20,
76 encoder_layers: 32,
77 intermediate_size: 5120,
78 scale_embedding: false,
79 }
80 }
81}
82
83#[derive(Debug, Clone, Deserialize)]
85pub struct VoxtralConfig {
86 pub audio_config: VoxtralAudioConfig,
87 pub text_config: Llama32Config,
88 #[serde(default = "default_audio_token_id")]
89 pub audio_token_id: u32,
90 #[serde(default = "default_projector_act")]
91 pub projector_hidden_act: String,
92 pub vocab_size: usize,
93}
94
/// Serde fallback for `VoxtralConfig::audio_token_id` when the key is missing.
fn default_audio_token_id() -> u32 {
    const FALLBACK_AUDIO_TOKEN_ID: u32 = 24;
    FALLBACK_AUDIO_TOKEN_ID
}
98
/// Serde fallback for `VoxtralConfig::projector_hidden_act` when the key is missing.
fn default_projector_act() -> String {
    String::from("gelu")
}
102
103impl VoxtralConfig {
104 pub fn from_file(path: &Path) -> Result<Self> {
105 let data = std::fs::read_to_string(path)?;
106 serde_json::from_str(&data).with_context(|| format!("parse Voxtral config {path:?}"))
107 }
108
109 pub fn llama_config(&self) -> &Llama32Config {
110 &self.text_config
111 }
112
113 pub fn validate(&self) -> Result<()> {
114 ensure!(
115 self.text_config.hidden_size > 0,
116 "text_config.hidden_size must be > 0"
117 );
118 ensure!(
119 self.audio_config.intermediate_size == self.audio_config.d_model * 4,
120 "audio_config.intermediate_size should be 4× d_model for the projector reshape"
121 );
122 Ok(())
123 }
124
125 pub fn tiny_synthetic() -> Self {
126 Self {
127 audio_config: VoxtralAudioConfig::tiny_synthetic(),
128 text_config: Llama32Config {
129 vocab_size: 32,
130 hidden_size: 16,
131 intermediate_size: 32,
132 num_hidden_layers: 1,
133 num_attention_heads: 4,
134 num_key_value_heads: 2,
135 max_position_embeddings: 16,
136 rms_norm_eps: 1e-5,
137 rope_theta: 100_000_000.0,
138 hidden_act: "silu".into(),
139 tie_word_embeddings: true,
140 attention_bias: false,
141 head_dim: Some(4),
142 rope_scaling: None,
143 },
144 audio_token_id: 24,
145 projector_hidden_act: "gelu".into(),
146 vocab_size: 32,
147 }
148 }
149
150 pub fn mini_3b() -> Self {
151 Self {
152 audio_config: VoxtralAudioConfig::mini_3b(),
153 text_config: Llama32Config {
154 vocab_size: 131_072,
155 hidden_size: 3072,
156 intermediate_size: 8192,
157 num_hidden_layers: 30,
158 num_attention_heads: 32,
159 num_key_value_heads: 8,
160 max_position_embeddings: 131_072,
161 rms_norm_eps: 1e-5,
162 rope_theta: 100_000_000.0,
163 hidden_act: "silu".into(),
164 tie_word_embeddings: true,
165 attention_bias: false,
166 head_dim: Some(128),
167 rope_scaling: None,
168 },
169 audio_token_id: 24,
170 projector_hidden_act: "gelu".into(),
171 vocab_size: 131_072,
172 }
173 }
174}