Skip to main content

rlx_voxtral_tts/
weights.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Map Voxtral-4B-TTS checkpoint keys → Llama builder keys.
17
18use rlx_core::weight_loader::WeightLoader;
19use std::collections::HashMap;
20
21/// Maps HF/Llama flow keys (`model.layers.*.input_layernorm`, …) to Mistral
22/// consolidated names (`layers.*.attention_norm`, `feed_forward`, …).
23pub struct BackbonePrefixLoader<'a> {
24    inner: &'a mut dyn WeightLoader,
25}
26
27impl<'a> BackbonePrefixLoader<'a> {
28    pub fn new(inner: &'a mut dyn WeightLoader) -> Self {
29        Self { inner }
30    }
31
32    pub fn map_key(key: &str) -> String {
33        match key {
34            "model.norm.weight" => "norm.weight".into(),
35            "model.embed_tokens.weight" => "tok_embeddings.weight".into(),
36            "lm_head.weight" => "output.weight".into(),
37            k if k.starts_with("model.layers.") => map_layer_key(k),
38            k if k.starts_with("layers.") => k.to_string(),
39            other => other.to_string(),
40        }
41    }
42}
43
/// Renames a per-layer HF key (`model.layers.<i>.<tail>`) to `layers.<i>.<mistral tail>`.
///
/// The layer index is everything up to the first `.` after the `model.layers.` prefix.
/// Feed-forward projections are recognised both with and without the `mlp.` segment.
/// Unrecognised tails keep their name but still gain the `layers.<i>.` prefix. A key with
/// no `.` after the index has no tail to rename and is returned untouched.
fn map_layer_key(key: &str) -> String {
    let after_prefix = key.strip_prefix("model.layers.").unwrap_or(key);
    let Some((layer, suffix)) = after_prefix.split_once('.') else {
        return key.to_owned();
    };
    let renamed = match suffix {
        "input_layernorm.weight" => "attention_norm.weight",
        "post_attention_layernorm.weight" => "ffn_norm.weight",
        "self_attn.q_proj.weight" => "attention.wq.weight",
        "self_attn.k_proj.weight" => "attention.wk.weight",
        "self_attn.v_proj.weight" => "attention.wv.weight",
        "self_attn.o_proj.weight" => "attention.wo.weight",
        // Mistral naming: w1 = gate, w3 = up, w2 = down.
        "mlp.gate_proj.weight" | "gate_proj.weight" => "feed_forward.w1.weight",
        "mlp.up_proj.weight" | "up_proj.weight" => "feed_forward.w3.weight",
        "mlp.down_proj.weight" | "down_proj.weight" => "feed_forward.w2.weight",
        unchanged => unchanged,
    };
    format!("layers.{layer}.{renamed}")
}
68
69impl WeightLoader for BackbonePrefixLoader<'_> {
70    fn len(&self) -> usize {
71        self.inner.len()
72    }
73
74    fn take(&mut self, key: &str) -> anyhow::Result<(Vec<f32>, Vec<usize>)> {
75        self.inner.take(&Self::map_key(key))
76    }
77
78    fn take_transposed(&mut self, key: &str) -> anyhow::Result<(Vec<f32>, Vec<usize>)> {
79        self.inner.take_transposed(&Self::map_key(key))
80    }
81
82    fn remaining_keys(&self) -> Vec<String> {
83        self.inner.remaining_keys()
84    }
85}
86
87/// Serves checkpoint tensors by name (used to rebuild graphs without reloading safetensors).
88pub struct CheckpointParamLoader {
89    params: HashMap<String, (Vec<f32>, Vec<usize>)>,
90}
91
92impl CheckpointParamLoader {
93    pub fn new(params: HashMap<String, (Vec<f32>, Vec<usize>)>) -> Self {
94        Self { params }
95    }
96}
97
98impl WeightLoader for CheckpointParamLoader {
99    fn len(&self) -> usize {
100        self.params.len()
101    }
102
103    fn take(&mut self, key: &str) -> anyhow::Result<(Vec<f32>, Vec<usize>)> {
104        self.params
105            .get(key)
106            .cloned()
107            .ok_or_else(|| anyhow::anyhow!("missing weight {key}"))
108    }
109
110    fn take_transposed(&mut self, key: &str) -> anyhow::Result<(Vec<f32>, Vec<usize>)> {
111        self.take(key)
112    }
113
114    fn remaining_keys(&self) -> Vec<String> {
115        self.params.keys().cloned().collect()
116    }
117}
118
119/// Maps flow keys (`layers.*`) → checkpoint keys (`acoustic_transformer.*`).
120pub struct AcousticPrefixLoader<'a> {
121    inner: &'a mut rlx_core::weight_map::WeightMap,
122}
123
124impl<'a> AcousticPrefixLoader<'a> {
125    pub fn new(inner: &'a mut rlx_core::weight_map::WeightMap) -> Self {
126        Self { inner }
127    }
128
129    fn full_key(key: &str) -> String {
130        if key.starts_with(crate::load::PREFIX_ACOUSTIC) {
131            key.to_string()
132        } else {
133            format!("{}{key}", crate::load::PREFIX_ACOUSTIC)
134        }
135    }
136}
137
138impl WeightLoader for AcousticPrefixLoader<'_> {
139    fn len(&self) -> usize {
140        self.inner.len()
141    }
142
143    fn take(&mut self, key: &str) -> anyhow::Result<(Vec<f32>, Vec<usize>)> {
144        self.inner.take(&Self::full_key(key))
145    }
146
147    fn take_transposed(&mut self, key: &str) -> anyhow::Result<(Vec<f32>, Vec<usize>)> {
148        self.inner.take_transposed(&Self::full_key(key))
149    }
150
151    fn remaining_keys(&self) -> Vec<String> {
152        self.inner
153            .remaining_keys()
154            .into_iter()
155            .filter_map(|k| {
156                k.strip_prefix(crate::load::PREFIX_ACOUSTIC)
157                    .map(str::to_string)
158            })
159            .collect()
160    }
161}
162
163/// Clone-on-take loader for one-shot compiles (acoustic snapshot).
164pub struct SnapshotLoader {
165    map: HashMap<String, (Vec<f32>, Vec<usize>)>,
166}
167
168impl SnapshotLoader {
169    pub fn new(map: HashMap<String, (Vec<f32>, Vec<usize>)>) -> Self {
170        Self { map }
171    }
172}
173
174impl WeightLoader for SnapshotLoader {
175    fn len(&self) -> usize {
176        self.map.len()
177    }
178
179    fn take(&mut self, key: &str) -> anyhow::Result<(Vec<f32>, Vec<usize>)> {
180        self.map
181            .remove(key)
182            .ok_or_else(|| anyhow::anyhow!("missing weight {key}"))
183    }
184
185    fn take_transposed(&mut self, key: &str) -> anyhow::Result<(Vec<f32>, Vec<usize>)> {
186        self.take(key)
187    }
188
189    fn remaining_keys(&self) -> Vec<String> {
190        self.map.keys().cloned().collect()
191    }
192}
193
194pub(crate) fn snapshot_backbone_params(
195    store: &crate::load::VoxtralTtsWeightStore,
196) -> anyhow::Result<crate::load::WeightSnapshot> {
197    let mut wm = store.load_backbone()?;
198    let keys: Vec<String> = wm.keys().map(str::to_string).collect();
199    let mut out = HashMap::with_capacity(keys.len());
200    for key in keys {
201        out.insert(key.clone(), wm.take(&key)?);
202    }
203    Ok(out)
204}
205
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn maps_hf_layer_keys_to_mistral_names() {
        assert_eq!(
            BackbonePrefixLoader::map_key("model.layers.3.input_layernorm.weight"),
            "layers.3.attention_norm.weight"
        );
        assert_eq!(
            BackbonePrefixLoader::map_key("model.layers.0.mlp.gate_proj.weight"),
            "layers.0.feed_forward.w1.weight"
        );
        assert_eq!(
            BackbonePrefixLoader::map_key("model.norm.weight"),
            "norm.weight"
        );
    }

    /// Pins every per-layer rename, including FFN projections without the `mlp.` segment.
    #[test]
    fn maps_every_layer_projection() {
        let cases = [
            ("model.layers.1.post_attention_layernorm.weight", "layers.1.ffn_norm.weight"),
            ("model.layers.2.self_attn.q_proj.weight", "layers.2.attention.wq.weight"),
            ("model.layers.2.self_attn.k_proj.weight", "layers.2.attention.wk.weight"),
            ("model.layers.2.self_attn.v_proj.weight", "layers.2.attention.wv.weight"),
            ("model.layers.2.self_attn.o_proj.weight", "layers.2.attention.wo.weight"),
            ("model.layers.4.mlp.up_proj.weight", "layers.4.feed_forward.w3.weight"),
            ("model.layers.4.mlp.down_proj.weight", "layers.4.feed_forward.w2.weight"),
            ("model.layers.5.gate_proj.weight", "layers.5.feed_forward.w1.weight"),
            ("model.layers.5.up_proj.weight", "layers.5.feed_forward.w3.weight"),
            ("model.layers.5.down_proj.weight", "layers.5.feed_forward.w2.weight"),
        ];
        for (hf, mistral) in cases {
            assert_eq!(BackbonePrefixLoader::map_key(hf), mistral, "mapping {hf}");
        }
    }

    #[test]
    fn maps_top_level_keys() {
        assert_eq!(
            BackbonePrefixLoader::map_key("model.embed_tokens.weight"),
            "tok_embeddings.weight"
        );
        assert_eq!(BackbonePrefixLoader::map_key("lm_head.weight"), "output.weight");
    }

    #[test]
    fn passes_through_consolidated_and_unknown_keys() {
        for key in [
            "layers.3.attention_norm.weight",
            "norm.weight",
            "output.weight",
            "tok_embeddings.weight",
            "some.other.tensor",
        ] {
            assert_eq!(BackbonePrefixLoader::map_key(key), key);
        }
        // Unrecognised per-layer tails keep their name but gain the consolidated prefix.
        assert_eq!(
            BackbonePrefixLoader::map_key("model.layers.9.rotary_emb.inv_freq"),
            "layers.9.rotary_emb.inv_freq"
        );
    }

    /// `remaining_keys` on the backbone loader reports consolidated names, so feeding a
    /// mapped key back through `map_key` must not change it.
    #[test]
    fn map_key_is_idempotent() {
        for hf in [
            "model.layers.0.self_attn.q_proj.weight",
            "model.layers.7.mlp.down_proj.weight",
            "model.norm.weight",
            "lm_head.weight",
        ] {
            let once = BackbonePrefixLoader::map_key(hf);
            assert_eq!(BackbonePrefixLoader::map_key(&once), once);
        }
    }
}