1use anyhow::Result;
19use rlx_core::weight_loader::WeightLoader;
20use rlx_core::weight_map::WeightMap;
21
/// Zero-field namespace type whose associated functions spell out tensor key
/// names in three families: `audio_tower.*`, `multi_modal_projector.*` and
/// `language_model.*`.
///
/// Fixed keys are returned as `&'static str`; per-layer keys are formatted on
/// demand and returned as an owned `String`.
#[derive(Debug, Clone)]
pub struct VoxtralWeightPrefix;

impl VoxtralWeightPrefix {
    // ---- `audio_tower.*` keys ------------------------------------------------

    /// Key of the `conv1` weight under `audio_tower`.
    pub fn enc_conv1_w() -> &'static str {
        "audio_tower.conv1.weight"
    }

    /// Key of the `conv1` bias under `audio_tower`.
    pub fn enc_conv1_b() -> &'static str {
        "audio_tower.conv1.bias"
    }

    /// Key of the `conv2` weight under `audio_tower`.
    pub fn enc_conv2_w() -> &'static str {
        "audio_tower.conv2.weight"
    }

    /// Key of the `conv2` bias under `audio_tower`.
    pub fn enc_conv2_b() -> &'static str {
        "audio_tower.conv2.bias"
    }

    /// Key of the `embed_positions` weight under `audio_tower`.
    pub fn enc_embed_positions() -> &'static str {
        "audio_tower.embed_positions.weight"
    }

    /// Builds `audio_tower.layers.<i>.<suffix>`.
    ///
    /// `suffix` is appended verbatim after the layer index and a dot; it is
    /// not validated.
    pub fn enc_layer(i: usize, suffix: &str) -> String {
        format!("audio_tower.layers.{i}.{suffix}")
    }

    /// Key of the `layer_norm` weight under `audio_tower`.
    pub fn enc_ln_post_w() -> &'static str {
        "audio_tower.layer_norm.weight"
    }

    /// Key of the `layer_norm` bias under `audio_tower`.
    pub fn enc_ln_post_b() -> &'static str {
        "audio_tower.layer_norm.bias"
    }

    // ---- `multi_modal_projector.*` keys --------------------------------------

    /// Key of the `linear_1` weight under `multi_modal_projector`.
    pub fn projector_linear1() -> &'static str {
        "multi_modal_projector.linear_1.weight"
    }

    /// Key of the `linear_2` weight under `multi_modal_projector`.
    pub fn projector_linear2() -> &'static str {
        "multi_modal_projector.linear_2.weight"
    }

    // ---- `language_model.*` keys ---------------------------------------------

    /// Key of the `embed_tokens` weight under `language_model.model`.
    pub fn lm_embed_tokens() -> &'static str {
        "language_model.model.embed_tokens.weight"
    }

    /// Builds `language_model.model.layers.<i>.<suffix>`.
    ///
    /// `suffix` is appended verbatim after the layer index and a dot; it is
    /// not validated.
    pub fn lm_layer(i: usize, suffix: &str) -> String {
        format!("language_model.model.layers.{i}.{suffix}")
    }

    /// Key of the `norm` weight under `language_model.model`.
    pub fn lm_norm() -> &'static str {
        "language_model.model.norm.weight"
    }

    /// Key of the `lm_head` weight under `language_model`.
    pub fn lm_head() -> &'static str {
        "language_model.lm_head.weight"
    }
}
83
84fn map_lm_key(key: &str) -> String {
85 match key {
86 "model.embed_tokens.weight" => VoxtralWeightPrefix::lm_embed_tokens().to_string(),
87 "model.norm.weight" => VoxtralWeightPrefix::lm_norm().to_string(),
88 "lm_head.weight" => VoxtralWeightPrefix::lm_head().to_string(),
89 k if k.starts_with("model.layers.") => format!("language_model.{k}"),
90 other => other.to_string(),
91 }
92}
93
94pub struct LanguageModelPrefixLoader<'a> {
96 inner: &'a mut WeightMap,
97}
98
99impl<'a> LanguageModelPrefixLoader<'a> {
100 pub fn new(inner: &'a mut WeightMap) -> Self {
101 Self { inner }
102 }
103}
104
105impl WeightLoader for LanguageModelPrefixLoader<'_> {
106 fn len(&self) -> usize {
107 self.inner.len()
108 }
109
110 fn take(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)> {
111 self.inner.take(&map_lm_key(key))
112 }
113
114 fn take_transposed(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)> {
115 self.inner.take_transposed(&map_lm_key(key))
116 }
117
118 fn remaining_keys(&self) -> Vec<String> {
119 self.inner.remaining_keys()
120 }
121}