// rlx_voxtral/weights.rs
// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
//! Checkpoint tensor names for `mistralai/Voxtral-*` safetensors, plus a loader adapter that maps Llama-style language-model keys onto them.
17
18use anyhow::Result;
19use rlx_core::weight_loader::WeightLoader;
20use rlx_core::weight_map::WeightMap;
21
/// Tensor-name helpers for Hugging Face `mistralai/Voxtral-*` safetensors checkpoints.
///
/// Checkpoint tensors live under three top-level prefixes: `audio_tower.*`,
/// `multi_modal_projector.*` and `language_model.*`. The type carries no data and
/// only serves as a namespace for the associated functions below. The fixed names
/// are `const fn`s, so they can also be used to initialise constants.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct VoxtralWeightPrefix;

impl VoxtralWeightPrefix {
    /// Returns `audio_tower.layers.{i}.{suffix}`: tensor `suffix` of audio-tower layer `i`.
    pub fn enc_layer(i: usize, suffix: &str) -> String {
        format!("audio_tower.layers.{i}.{suffix}")
    }

    /// Returns `audio_tower.conv1.weight`.
    pub const fn enc_conv1_w() -> &'static str {
        "audio_tower.conv1.weight"
    }

    /// Returns `audio_tower.conv1.bias`.
    pub const fn enc_conv1_b() -> &'static str {
        "audio_tower.conv1.bias"
    }

    /// Returns `audio_tower.conv2.weight`.
    pub const fn enc_conv2_w() -> &'static str {
        "audio_tower.conv2.weight"
    }

    /// Returns `audio_tower.conv2.bias`.
    pub const fn enc_conv2_b() -> &'static str {
        "audio_tower.conv2.bias"
    }

    /// Returns `audio_tower.embed_positions.weight`.
    pub const fn enc_embed_positions() -> &'static str {
        "audio_tower.embed_positions.weight"
    }

    /// Returns `audio_tower.layer_norm.weight` (the audio tower's final layer norm).
    pub const fn enc_ln_post_w() -> &'static str {
        "audio_tower.layer_norm.weight"
    }

    /// Returns `audio_tower.layer_norm.bias` (the audio tower's final layer norm).
    pub const fn enc_ln_post_b() -> &'static str {
        "audio_tower.layer_norm.bias"
    }

    /// Returns `multi_modal_projector.linear_1.weight`.
    pub const fn projector_linear1() -> &'static str {
        "multi_modal_projector.linear_1.weight"
    }

    /// Returns `multi_modal_projector.linear_2.weight`.
    pub const fn projector_linear2() -> &'static str {
        "multi_modal_projector.linear_2.weight"
    }

    /// Returns `language_model.model.embed_tokens.weight`.
    pub const fn lm_embed_tokens() -> &'static str {
        "language_model.model.embed_tokens.weight"
    }

    /// Returns `language_model.lm_head.weight`.
    pub const fn lm_head() -> &'static str {
        "language_model.lm_head.weight"
    }

    /// Returns `language_model.model.layers.{i}.{suffix}`: tensor `suffix` of
    /// language-model layer `i`.
    pub fn lm_layer(i: usize, suffix: &str) -> String {
        format!("language_model.model.layers.{i}.{suffix}")
    }

    /// Returns `language_model.model.norm.weight`.
    pub const fn lm_norm() -> &'static str {
        "language_model.model.norm.weight"
    }
}
83
/// Translates a Llama-style tensor name into the name used by a Voxtral checkpoint.
///
/// Every key in the `model.*` or `lm_head.*` namespace gets `language_model.`
/// prepended. For example, `model.layers.0.self_attn.q_proj.weight` becomes
/// `language_model.model.layers.0.self_attn.q_proj.weight`, and `lm_head.weight`
/// becomes `language_model.lm_head.weight`. These are the same names produced by
/// `VoxtralWeightPrefix::{lm_embed_tokens, lm_norm, lm_head, lm_layer}`.
///
/// Any other key is returned unchanged. That includes names that already carry the
/// `language_model.` prefix and `audio_tower.*` / `multi_modal_projector.*` names,
/// so applying the mapping twice gives the same result as applying it once.
fn map_lm_key(key: &str) -> String {
    const LM_PREFIX: &str = "language_model.";
    // Prefix the whole namespace rather than a hand-picked list of names. Voxtral
    // tensors sit under `audio_tower.`, `language_model.` or
    // `multi_modal_projector.`, so an unprefixed `model.*`/`lm_head.*` lookup
    // (e.g. `lm_head.bias`) could never succeed.
    if key.starts_with("model.") || key.starts_with("lm_head.") {
        [LM_PREFIX, key].concat()
    } else {
        key.to_owned()
    }
}
93
94/// Maps Llama-shaped keys (`model.*`, `lm_head.*`) to Voxtral safetensor names.
95pub struct LanguageModelPrefixLoader<'a> {
96    inner: &'a mut WeightMap,
97}
98
99impl<'a> LanguageModelPrefixLoader<'a> {
100    pub fn new(inner: &'a mut WeightMap) -> Self {
101        Self { inner }
102    }
103}
104
105impl WeightLoader for LanguageModelPrefixLoader<'_> {
106    fn len(&self) -> usize {
107        self.inner.len()
108    }
109
110    fn take(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)> {
111        self.inner.take(&map_lm_key(key))
112    }
113
114    fn take_transposed(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)> {
115        self.inner.take_transposed(&map_lm_key(key))
116    }
117
118    fn remaining_keys(&self) -> Vec<String> {
119        self.inner.remaining_keys()
120    }
121}