Skip to main content

rlx_locateanything/
weights.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Checkpoint tensor names for `nvidia/LocateAnything-3B` safetensors.
17
18use anyhow::{Context, Result};
19use rlx_core::weight_loader::WeightLoader;
20use std::sync::Arc;
21
22use crate::load::{LocateAnythingWeightStore, WeightSnapshot};
23
/// Tensor-name helpers for the HF `nvidia/LocateAnything-3B` checkpoint.
///
/// Covers the `vision_model.*`, `mlp1.*` and `language_model.*` namespaces.
/// Fixed names are returned as `&'static str`. Indexed names (per vision block
/// or per LM layer) are built on demand as an owned `String`.
#[derive(Debug, Clone)]
pub struct LocateAnythingWeightPrefix;

impl LocateAnythingWeightPrefix {
    /// Common prefix of every per-block vision-encoder tensor.
    const VISION_BLOCKS: &'static str = "vision_model.encoder.blocks";
    /// Common prefix of every per-layer language-model tensor.
    const LM_LAYERS: &'static str = "language_model.model.layers";

    // --- vision encoder (`vision_model.*`) ---

    /// Name of tensor `suffix` inside vision block `i`:
    /// `vision_model.encoder.blocks.{i}.{suffix}`.
    pub fn vision_block(i: usize, suffix: &str) -> String {
        format!("{}.{i}.{suffix}", Self::VISION_BLOCKS)
    }

    /// Patch-embedding projection weight.
    pub fn vision_patch_proj_w() -> &'static str {
        "vision_model.patch_embed.proj.weight"
    }

    /// Patch-embedding projection bias.
    pub fn vision_patch_proj_b() -> &'static str {
        "vision_model.patch_embed.proj.bias"
    }

    /// Position-embedding table of the patch embedder.
    pub fn vision_pos_emb() -> &'static str {
        "vision_model.patch_embed.pos_emb.weight"
    }

    /// Weight of the vision encoder's final layer norm.
    pub fn vision_final_ln_w() -> &'static str {
        "vision_model.encoder.final_layernorm.weight"
    }

    /// Bias of the vision encoder's final layer norm.
    pub fn vision_final_ln_b() -> &'static str {
        "vision_model.encoder.final_layernorm.bias"
    }

    // --- projector (`mlp1.*`; this type names indices 0, 1 and 3) ---

    /// `mlp1.0` weight (projector layer norm).
    pub fn projector_ln_w() -> &'static str {
        "mlp1.0.weight"
    }

    /// `mlp1.0` bias (projector layer norm).
    pub fn projector_ln_b() -> &'static str {
        "mlp1.0.bias"
    }

    /// `mlp1.1` weight (first projector linear).
    pub fn projector_fc1_w() -> &'static str {
        "mlp1.1.weight"
    }

    /// `mlp1.1` bias (first projector linear).
    pub fn projector_fc1_b() -> &'static str {
        "mlp1.1.bias"
    }

    /// `mlp1.3` weight (second projector linear).
    pub fn projector_fc2_w() -> &'static str {
        "mlp1.3.weight"
    }

    /// `mlp1.3` bias (second projector linear).
    pub fn projector_fc2_b() -> &'static str {
        "mlp1.3.bias"
    }

    // --- language model (`language_model.*`) ---

    /// Token-embedding table of the language model.
    pub fn lm_embed_tokens() -> &'static str {
        "language_model.model.embed_tokens.weight"
    }

    /// Name of tensor `suffix` inside LM layer `i`:
    /// `language_model.model.layers.{i}.{suffix}`.
    pub fn lm_layer(i: usize, suffix: &str) -> String {
        format!("{}.{i}.{suffix}", Self::LM_LAYERS)
    }

    /// Final norm weight of the language model.
    pub fn lm_norm() -> &'static str {
        "language_model.model.norm.weight"
    }

    /// Output head weight of the language model.
    pub fn lm_head() -> &'static str {
        "language_model.lm_head.weight"
    }
}
93
/// Translates a Qwen-style LM tensor name into its LocateAnything checkpoint name.
///
/// Every `model.*` or `lm_head.*` key gets the `language_model.` prefix. For example,
/// `model.norm.weight` becomes `language_model.model.norm.weight`, and
/// `lm_head.weight` becomes `language_model.lm_head.weight`.
///
/// This agrees with the fixed names from `LocateAnythingWeightPrefix`
/// (`lm_embed_tokens`, `lm_norm`, `lm_head`, `lm_layer`). It also covers any other
/// LM tensor a loader asks for; previously such keys fell through unprefixed and
/// were reported as missing.
///
/// All other keys are returned unchanged, including already-prefixed
/// `language_model.*` names, `vision_model.*` and `mlp1.*`.
fn map_lm_key(key: &str) -> String {
    if key.starts_with("model.") || key.starts_with("lm_head.") {
        format!("language_model.{key}")
    } else {
        key.to_owned()
    }
}
103
104/// Maps Qwen-shaped keys (`model.*`, `lm_head.*`) to HF `language_model.*` names.
105pub struct LanguageModelPrefixLoader<'a> {
106    inner: &'a mut dyn WeightLoader,
107}
108
109impl<'a> LanguageModelPrefixLoader<'a> {
110    pub fn new(inner: &'a mut dyn WeightLoader) -> Self {
111        Self { inner }
112    }
113}
114
115impl WeightLoader for LanguageModelPrefixLoader<'_> {
116    fn len(&self) -> usize {
117        self.inner.len()
118    }
119
120    fn take(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)> {
121        self.inner.take(&map_lm_key(key))
122    }
123
124    fn take_transposed(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)> {
125        self.inner.take_transposed(&map_lm_key(key))
126    }
127
128    fn remaining_keys(&self) -> Vec<String> {
129        self.inner.remaining_keys()
130    }
131}
132
133/// LM weights loaded on demand from mmap-backed safetensors (no full-RAM snapshot).
134pub struct CheckpointLmWeightLoader {
135    store: Arc<LocateAnythingWeightStore>,
136}
137
138impl CheckpointLmWeightLoader {
139    pub fn new(store: Arc<LocateAnythingWeightStore>) -> Self {
140        Self { store }
141    }
142
143    fn take_hf(&self, hf: &str) -> Result<(Vec<f32>, Vec<usize>)> {
144        let mut wm = self
145            .store
146            .load_keys(&[hf])
147            .with_context(|| format!("load LM weight {hf}"))?;
148        wm.take(hf)
149            .with_context(|| format!("missing LM weight {hf} after load"))
150    }
151}
152
153impl WeightLoader for CheckpointLmWeightLoader {
154    fn len(&self) -> usize {
155        self.store
156            .count_keys_with_prefix(crate::load::PREFIX_LANGUAGE_MODEL)
157    }
158
159    fn take(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)> {
160        self.take_hf(&map_lm_key(key))
161    }
162
163    fn take_transposed(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)> {
164        let hf = map_lm_key(key);
165        let (data, shape) = self.take_hf(&hf)?;
166        if shape.len() != 2 {
167            anyhow::bail!("transpose requires rank-2 weight: {key}");
168        }
169        let rows = shape[0];
170        let cols = shape[1];
171        let mut out = vec![0f32; rows * cols];
172        for r in 0..rows {
173            for c in 0..cols {
174                out[c * rows + r] = data[r * cols + c];
175            }
176        }
177        Ok((out, vec![cols, rows]))
178    }
179
180    fn remaining_keys(&self) -> Vec<String> {
181        self.store
182            .keys()
183            .iter()
184            .filter(|k| k.starts_with(crate::load::PREFIX_LANGUAGE_MODEL))
185            .cloned()
186            .collect()
187    }
188}
189
190/// LM weights from a shared snapshot — one tensor cloned per `take`, not the full map.
191pub struct ArcLmWeightLoader {
192    snapshot: Arc<WeightSnapshot>,
193}
194
195impl ArcLmWeightLoader {
196    pub fn new(snapshot: Arc<WeightSnapshot>) -> Self {
197        Self { snapshot }
198    }
199}
200
201impl WeightLoader for ArcLmWeightLoader {
202    fn len(&self) -> usize {
203        self.snapshot.len()
204    }
205
206    fn take(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)> {
207        let hf = map_lm_key(key);
208        let (data, shape) = self
209            .snapshot
210            .get(&hf)
211            .with_context(|| format!("missing weight {hf}"))?;
212        Ok((data.clone(), shape.clone()))
213    }
214
215    fn take_transposed(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)> {
216        let hf = map_lm_key(key);
217        let (data, shape) = self
218            .snapshot
219            .get(&hf)
220            .with_context(|| format!("missing weight {hf}"))?;
221        if shape.len() != 2 {
222            anyhow::bail!("transpose requires rank-2 weight: {key}");
223        }
224        let rows = shape[0];
225        let cols = shape[1];
226        let mut out = vec![0f32; rows * cols];
227        for r in 0..rows {
228            for c in 0..cols {
229                out[c * rows + r] = data[r * cols + c];
230            }
231        }
232        Ok((out, vec![cols, rows]))
233    }
234
235    fn remaining_keys(&self) -> Vec<String> {
236        self.snapshot.keys().cloned().collect()
237    }
238}