rlx_voxtral_tts/
weights.rs1use rlx_core::weight_loader::WeightLoader;
19use std::collections::HashMap;
20
21pub struct BackbonePrefixLoader<'a> {
24 inner: &'a mut dyn WeightLoader,
25}
26
27impl<'a> BackbonePrefixLoader<'a> {
28 pub fn new(inner: &'a mut dyn WeightLoader) -> Self {
29 Self { inner }
30 }
31
32 pub fn map_key(key: &str) -> String {
33 match key {
34 "model.norm.weight" => "norm.weight".into(),
35 "model.embed_tokens.weight" => "tok_embeddings.weight".into(),
36 "lm_head.weight" => "output.weight".into(),
37 k if k.starts_with("model.layers.") => map_layer_key(k),
38 k if k.starts_with("layers.") => k.to_string(),
39 other => other.to_string(),
40 }
41 }
42}
43
/// Rewrites a per-layer HF-style key (`model.layers.<idx>.<tail>`) into the
/// `layers.<idx>.<tail>` form, renaming the known norm / attention / MLP tails.
///
/// Returns the key unchanged when it does not start with `model.layers.` or
/// when nothing follows the layer index. Previously a key without the prefix
/// was still split at its first dot and re-emitted under `layers.`, so
/// `layers.3.x` became `layers.layers.3.x`.
///
/// Tails that are not in the table are kept verbatim; only the leading
/// `model.` is dropped for them.
fn map_layer_key(key: &str) -> String {
    let Some(rest) = key.strip_prefix("model.layers.") else {
        return key.to_string();
    };
    // `rest` looks like `<idx>.<tail>`; without a dot there is no tail to map.
    let Some((idx, tail)) = rest.split_once('.') else {
        return key.to_string();
    };
    let mapped = match tail {
        "input_layernorm.weight" => "attention_norm.weight",
        "post_attention_layernorm.weight" => "ffn_norm.weight",
        "self_attn.q_proj.weight" => "attention.wq.weight",
        "self_attn.k_proj.weight" => "attention.wk.weight",
        "self_attn.v_proj.weight" => "attention.wv.weight",
        "self_attn.o_proj.weight" => "attention.wo.weight",
        // MLP projections are accepted with or without the `mlp.` segment.
        "mlp.gate_proj.weight" | "gate_proj.weight" => "feed_forward.w1.weight",
        "mlp.up_proj.weight" | "up_proj.weight" => "feed_forward.w3.weight",
        "mlp.down_proj.weight" | "down_proj.weight" => "feed_forward.w2.weight",
        other => other,
    };
    format!("layers.{idx}.{mapped}")
}
68
69impl WeightLoader for BackbonePrefixLoader<'_> {
70 fn len(&self) -> usize {
71 self.inner.len()
72 }
73
74 fn take(&mut self, key: &str) -> anyhow::Result<(Vec<f32>, Vec<usize>)> {
75 self.inner.take(&Self::map_key(key))
76 }
77
78 fn take_transposed(&mut self, key: &str) -> anyhow::Result<(Vec<f32>, Vec<usize>)> {
79 self.inner.take_transposed(&Self::map_key(key))
80 }
81
82 fn remaining_keys(&self) -> Vec<String> {
83 self.inner.remaining_keys()
84 }
85}
86
/// In-memory loader backed by a map from tensor name to `(data, shape)`.
///
/// Its `WeightLoader::take` implementation clones entries and does not remove
/// them, so the map keeps every parameter for the loader's whole lifetime.
pub struct CheckpointParamLoader {
    /// Tensor name -> (flat `f32` values, dimensions).
    params: HashMap<String, (Vec<f32>, Vec<usize>)>,
}

impl CheckpointParamLoader {
    /// Builds a loader that serves the given parameter map.
    pub fn new(params: HashMap<String, (Vec<f32>, Vec<usize>)>) -> Self {
        CheckpointParamLoader { params }
    }
}
97
98impl WeightLoader for CheckpointParamLoader {
99 fn len(&self) -> usize {
100 self.params.len()
101 }
102
103 fn take(&mut self, key: &str) -> anyhow::Result<(Vec<f32>, Vec<usize>)> {
104 self.params
105 .get(key)
106 .cloned()
107 .ok_or_else(|| anyhow::anyhow!("missing weight {key}"))
108 }
109
110 fn take_transposed(&mut self, key: &str) -> anyhow::Result<(Vec<f32>, Vec<usize>)> {
111 self.take(key)
112 }
113
114 fn remaining_keys(&self) -> Vec<String> {
115 self.params.keys().cloned().collect()
116 }
117}
118
119pub struct AcousticPrefixLoader<'a> {
121 inner: &'a mut rlx_core::weight_map::WeightMap,
122}
123
124impl<'a> AcousticPrefixLoader<'a> {
125 pub fn new(inner: &'a mut rlx_core::weight_map::WeightMap) -> Self {
126 Self { inner }
127 }
128
129 fn full_key(key: &str) -> String {
130 if key.starts_with(crate::load::PREFIX_ACOUSTIC) {
131 key.to_string()
132 } else {
133 format!("{}{key}", crate::load::PREFIX_ACOUSTIC)
134 }
135 }
136}
137
138impl WeightLoader for AcousticPrefixLoader<'_> {
139 fn len(&self) -> usize {
140 self.inner.len()
141 }
142
143 fn take(&mut self, key: &str) -> anyhow::Result<(Vec<f32>, Vec<usize>)> {
144 self.inner.take(&Self::full_key(key))
145 }
146
147 fn take_transposed(&mut self, key: &str) -> anyhow::Result<(Vec<f32>, Vec<usize>)> {
148 self.inner.take_transposed(&Self::full_key(key))
149 }
150
151 fn remaining_keys(&self) -> Vec<String> {
152 self.inner
153 .remaining_keys()
154 .into_iter()
155 .filter_map(|k| {
156 k.strip_prefix(crate::load::PREFIX_ACOUSTIC)
157 .map(str::to_string)
158 })
159 .collect()
160 }
161}
162
/// In-memory loader backed by a map from tensor name to `(data, shape)`.
///
/// Its `WeightLoader::take` implementation removes entries from the map, so
/// each tensor can be taken only once.
pub struct SnapshotLoader {
    /// Tensor name -> (flat `f32` values, dimensions) for entries not yet taken.
    map: HashMap<String, (Vec<f32>, Vec<usize>)>,
}

impl SnapshotLoader {
    /// Builds a loader that serves, and gradually drains, the given map.
    pub fn new(map: HashMap<String, (Vec<f32>, Vec<usize>)>) -> Self {
        SnapshotLoader { map }
    }
}
173
174impl WeightLoader for SnapshotLoader {
175 fn len(&self) -> usize {
176 self.map.len()
177 }
178
179 fn take(&mut self, key: &str) -> anyhow::Result<(Vec<f32>, Vec<usize>)> {
180 self.map
181 .remove(key)
182 .ok_or_else(|| anyhow::anyhow!("missing weight {key}"))
183 }
184
185 fn take_transposed(&mut self, key: &str) -> anyhow::Result<(Vec<f32>, Vec<usize>)> {
186 self.take(key)
187 }
188
189 fn remaining_keys(&self) -> Vec<String> {
190 self.map.keys().cloned().collect()
191 }
192}
193
194pub(crate) fn snapshot_backbone_params(
195 store: &crate::load::VoxtralTtsWeightStore,
196) -> anyhow::Result<crate::load::WeightSnapshot> {
197 let mut wm = store.load_backbone()?;
198 let keys: Vec<String> = wm.keys().map(str::to_string).collect();
199 let mut out = HashMap::with_capacity(keys.len());
200 for key in keys {
201 out.insert(key.clone(), wm.take(&key)?);
202 }
203 Ok(out)
204}
205
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks `BackbonePrefixLoader::map_key` on two per-layer keys and one
    /// top-level key.
    #[test]
    fn maps_hf_layer_keys_to_mistral_names() {
        // (HF-style input, expected Mistral-style output)
        let cases = [
            (
                "model.layers.3.input_layernorm.weight",
                "layers.3.attention_norm.weight",
            ),
            (
                "model.layers.0.mlp.gate_proj.weight",
                "layers.0.feed_forward.w1.weight",
            ),
            ("model.norm.weight", "norm.weight"),
        ];

        for &(hf_name, mistral_name) in &cases {
            assert_eq!(BackbonePrefixLoader::map_key(hf_name), mistral_name);
        }
    }
}