rlx_locateanything/
weights.rs1use anyhow::{Context, Result};
19use rlx_core::weight_loader::WeightLoader;
20use std::sync::Arc;
21
22use crate::load::{LocateAnythingWeightStore, WeightSnapshot};
23
/// Namespace type for the tensor key strings used when loading weights.
///
/// Every associated function returns a key; the two builders
/// ([`Self::vision_block`] and [`Self::lm_layer`]) splice a layer index and a
/// caller-supplied suffix into a fixed, dot-separated prefix.
#[derive(Debug, Clone)]
pub struct LocateAnythingWeightPrefix;

impl LocateAnythingWeightPrefix {
    /// Builds `vision_model.encoder.blocks.{i}.{suffix}`.
    pub fn vision_block(i: usize, suffix: &str) -> String {
        let index = i.to_string();
        ["vision_model.encoder.blocks", index.as_str(), suffix].join(".")
    }

    /// Key `vision_model.patch_embed.proj.weight`.
    pub fn vision_patch_proj_w() -> &'static str {
        "vision_model.patch_embed.proj.weight"
    }

    /// Key `vision_model.patch_embed.proj.bias`.
    pub fn vision_patch_proj_b() -> &'static str {
        "vision_model.patch_embed.proj.bias"
    }

    /// Key `vision_model.patch_embed.pos_emb.weight`.
    pub fn vision_pos_emb() -> &'static str {
        "vision_model.patch_embed.pos_emb.weight"
    }

    /// Key `vision_model.encoder.final_layernorm.weight`.
    pub fn vision_final_ln_w() -> &'static str {
        "vision_model.encoder.final_layernorm.weight"
    }

    /// Key `vision_model.encoder.final_layernorm.bias`.
    pub fn vision_final_ln_b() -> &'static str {
        "vision_model.encoder.final_layernorm.bias"
    }

    /// Key `mlp1.0.weight`.
    pub fn projector_ln_w() -> &'static str {
        "mlp1.0.weight"
    }

    /// Key `mlp1.0.bias`.
    pub fn projector_ln_b() -> &'static str {
        "mlp1.0.bias"
    }

    /// Key `mlp1.1.weight`.
    pub fn projector_fc1_w() -> &'static str {
        "mlp1.1.weight"
    }

    /// Key `mlp1.1.bias`.
    pub fn projector_fc1_b() -> &'static str {
        "mlp1.1.bias"
    }

    /// Key `mlp1.3.weight`.
    pub fn projector_fc2_w() -> &'static str {
        "mlp1.3.weight"
    }

    /// Key `mlp1.3.bias`.
    pub fn projector_fc2_b() -> &'static str {
        "mlp1.3.bias"
    }

    /// Key `language_model.model.embed_tokens.weight`.
    pub fn lm_embed_tokens() -> &'static str {
        "language_model.model.embed_tokens.weight"
    }

    /// Key `language_model.lm_head.weight`.
    pub fn lm_head() -> &'static str {
        "language_model.lm_head.weight"
    }

    /// Builds `language_model.model.layers.{i}.{suffix}`.
    pub fn lm_layer(i: usize, suffix: &str) -> String {
        let index = i.to_string();
        ["language_model.model.layers", index.as_str(), suffix].join(".")
    }

    /// Key `language_model.model.norm.weight`.
    pub fn lm_norm() -> &'static str {
        "language_model.model.norm.weight"
    }
}
93
94fn map_lm_key(key: &str) -> String {
95 match key {
96 "model.embed_tokens.weight" => LocateAnythingWeightPrefix::lm_embed_tokens().into(),
97 "model.norm.weight" => LocateAnythingWeightPrefix::lm_norm().into(),
98 "lm_head.weight" => LocateAnythingWeightPrefix::lm_head().into(),
99 k if k.starts_with("model.layers.") => format!("language_model.{k}"),
100 other => other.into(),
101 }
102}
103
104pub struct LanguageModelPrefixLoader<'a> {
106 inner: &'a mut dyn WeightLoader,
107}
108
109impl<'a> LanguageModelPrefixLoader<'a> {
110 pub fn new(inner: &'a mut dyn WeightLoader) -> Self {
111 Self { inner }
112 }
113}
114
115impl WeightLoader for LanguageModelPrefixLoader<'_> {
116 fn len(&self) -> usize {
117 self.inner.len()
118 }
119
120 fn take(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)> {
121 self.inner.take(&map_lm_key(key))
122 }
123
124 fn take_transposed(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)> {
125 self.inner.take_transposed(&map_lm_key(key))
126 }
127
128 fn remaining_keys(&self) -> Vec<String> {
129 self.inner.remaining_keys()
130 }
131}
132
133pub struct CheckpointLmWeightLoader {
135 store: Arc<LocateAnythingWeightStore>,
136}
137
138impl CheckpointLmWeightLoader {
139 pub fn new(store: Arc<LocateAnythingWeightStore>) -> Self {
140 Self { store }
141 }
142
143 fn take_hf(&self, hf: &str) -> Result<(Vec<f32>, Vec<usize>)> {
144 let mut wm = self
145 .store
146 .load_keys(&[hf])
147 .with_context(|| format!("load LM weight {hf}"))?;
148 wm.take(hf)
149 .with_context(|| format!("missing LM weight {hf} after load"))
150 }
151}
152
153impl WeightLoader for CheckpointLmWeightLoader {
154 fn len(&self) -> usize {
155 self.store
156 .count_keys_with_prefix(crate::load::PREFIX_LANGUAGE_MODEL)
157 }
158
159 fn take(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)> {
160 self.take_hf(&map_lm_key(key))
161 }
162
163 fn take_transposed(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)> {
164 let hf = map_lm_key(key);
165 let (data, shape) = self.take_hf(&hf)?;
166 if shape.len() != 2 {
167 anyhow::bail!("transpose requires rank-2 weight: {key}");
168 }
169 let rows = shape[0];
170 let cols = shape[1];
171 let mut out = vec![0f32; rows * cols];
172 for r in 0..rows {
173 for c in 0..cols {
174 out[c * rows + r] = data[r * cols + c];
175 }
176 }
177 Ok((out, vec![cols, rows]))
178 }
179
180 fn remaining_keys(&self) -> Vec<String> {
181 self.store
182 .keys()
183 .iter()
184 .filter(|k| k.starts_with(crate::load::PREFIX_LANGUAGE_MODEL))
185 .cloned()
186 .collect()
187 }
188}
189
190pub struct ArcLmWeightLoader {
192 snapshot: Arc<WeightSnapshot>,
193}
194
195impl ArcLmWeightLoader {
196 pub fn new(snapshot: Arc<WeightSnapshot>) -> Self {
197 Self { snapshot }
198 }
199}
200
201impl WeightLoader for ArcLmWeightLoader {
202 fn len(&self) -> usize {
203 self.snapshot.len()
204 }
205
206 fn take(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)> {
207 let hf = map_lm_key(key);
208 let (data, shape) = self
209 .snapshot
210 .get(&hf)
211 .with_context(|| format!("missing weight {hf}"))?;
212 Ok((data.clone(), shape.clone()))
213 }
214
215 fn take_transposed(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)> {
216 let hf = map_lm_key(key);
217 let (data, shape) = self
218 .snapshot
219 .get(&hf)
220 .with_context(|| format!("missing weight {hf}"))?;
221 if shape.len() != 2 {
222 anyhow::bail!("transpose requires rank-2 weight: {key}");
223 }
224 let rows = shape[0];
225 let cols = shape[1];
226 let mut out = vec![0f32; rows * cols];
227 for r in 0..rows {
228 for c in 0..cols {
229 out[c * rows + r] = data[r * cols + c];
230 }
231 }
232 Ok((out, vec![cols, rows]))
233 }
234
235 fn remaining_keys(&self) -> Vec<String> {
236 self.snapshot.keys().cloned().collect()
237 }
238}