1use anyhow::{Context, Result, bail, ensure};
19use std::collections::{HashMap, HashSet};
20use std::path::Path;
21
22use crate::gguf_support::{
23 gguf_architecture_from_path, gguf_safetensors_only_hint, resolve_weights_file,
24};
25use crate::weight_loader::WeightLoader;
26use crate::weight_registry::{LoadWeightsOptions, load_weight_map_resolved};
27use rlx_ir::quant::QuantScheme;
28
29pub type PackedWeightTensor = (Vec<u8>, QuantScheme, Vec<usize>);
31pub type NamedPackedWeight = (String, Vec<u8>, QuantScheme, Vec<usize>);
/// Plain map from tensor name to `(f32 values, shape)`, as returned by
/// `WeightMap::snapshot_from_path`.
pub type F32WeightSnapshot = HashMap<String, (Vec<f32>, Vec<usize>)>;
35
/// How `WeightMap::drain_loader` reacts when the loader still reports keys
/// after every initially listed key has been taken.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum WeightDrainPolicy {
    /// Ignore any leftover keys (the default).
    #[default]
    AllF32,
    /// Print a summary of leftover keys to stderr and carry on.
    AllF32WarnUnused,
    /// Fail with an error if any key is left over.
    AllF32StrictUnused,
}
46
/// In-memory collection of named `f32` tensors.
pub struct WeightMap {
    /// Tensor name -> (flat `f32` values, shape).
    tensors: HashMap<String, (Vec<f32>, Vec<usize>)>,
}
51
52impl WeightMap {
53 pub fn from_weight_loader(loader: &mut dyn WeightLoader) -> Result<Self> {
55 Self::drain_loader(loader, WeightDrainPolicy::AllF32).map(|(m, _)| m)
56 }
57
58 pub fn from_weight_loader_dequant_all(loader: &mut dyn WeightLoader) -> Result<Self> {
65 let keys = loader.remaining_keys();
66 let mut tensors = HashMap::with_capacity(keys.len());
67 for key in keys {
68 let (data, shape) = loader.take(&key)?;
69 tensors.insert(key, (data, shape));
70 }
71 Ok(Self { tensors })
72 }
73
74 pub fn drain_loader(
76 loader: &mut dyn WeightLoader,
77 policy: WeightDrainPolicy,
78 ) -> Result<(Self, Vec<NamedPackedWeight>)> {
79 let keys = loader.remaining_keys();
80 let mut tensors = HashMap::with_capacity(keys.len());
81 let mut packed = Vec::new();
82 for key in keys {
83 if let Some((bytes, scheme, shape)) = loader.take_packed(&key)? {
84 packed.push((key, bytes, scheme, shape));
85 continue;
86 }
87 let (data, shape) = loader.take(&key)?;
88 tensors.insert(key, (data, shape));
89 }
90 let left = loader.remaining_keys();
91 match policy {
92 WeightDrainPolicy::AllF32 => {}
93 WeightDrainPolicy::AllF32WarnUnused if !left.is_empty() => {
94 eprintln!(
95 "[rlx-core] weight drain: {} unused tensors (format={})",
96 left.len(),
97 loader.format_id()
98 );
99 for k in left.iter().take(8) {
100 eprintln!(" unused: {k}");
101 }
102 if left.len() > 8 {
103 eprintln!(" … and {} more", left.len() - 8);
104 }
105 }
106 WeightDrainPolicy::AllF32StrictUnused if !left.is_empty() => {
107 bail!(
108 "weight drain left {} unused tensors (format={}): {:?}",
109 left.len(),
110 loader.format_id(),
111 &left[..left.len().min(5)]
112 );
113 }
114 _ => {}
115 }
116 Ok((Self { tensors }, packed))
117 }
118
119 pub fn from_resolved_path(path: &Path) -> Result<Self> {
121 let file = resolve_weights_file(path)?;
122 Self::from_resolved_file(&file)
123 }
124
125 pub fn from_resolved_safetensors_only(path: &Path, runner: &str) -> Result<Self> {
127 let file = resolve_weights_file(path)?;
128 if file.extension().and_then(|s| s.to_str()) == Some("gguf") {
129 let arch = gguf_architecture_from_path(&file)?;
130 bail!(gguf_safetensors_only_hint(runner, &file, &arch));
131 }
132 Self::from_resolved_file(&file)
133 }
134
135 fn from_resolved_file(file: &Path) -> Result<Self> {
136 load_weight_map_resolved(file, LoadWeightsOptions::map()).map(|(_, m)| m)
137 }
138
139 pub fn from_file(path: &str) -> Result<Self> {
141 Self::from_file_excluding(path, &HashSet::new())
142 }
143
144 pub fn from_file_excluding(path: &str, exclude: &HashSet<String>) -> Result<Self> {
147 let data = std::fs::read(path).with_context(|| format!("reading {path}"))?;
148 let st =
149 safetensors::SafeTensors::deserialize(&data).with_context(|| "parsing safetensors")?;
150
151 let mut tensors = HashMap::new();
152 for (name, view) in st.tensors() {
153 if exclude.contains(name.as_str()) {
154 continue;
155 }
156 let shape: Vec<usize> = view.shape().to_vec();
157 let bytes = view.data();
158 let f32_data = match view.dtype() {
159 safetensors::Dtype::F32 => bytes_to_f32_vec(bytes),
160 safetensors::Dtype::F16 => bytes
161 .chunks_exact(2)
162 .map(|c| half::f16::from_le_bytes([c[0], c[1]]).to_f32())
163 .collect(),
164 safetensors::Dtype::BF16 => bytes
165 .chunks_exact(2)
166 .map(|c| half::bf16::from_le_bytes([c[0], c[1]]).to_f32())
167 .collect(),
168 safetensors::Dtype::I64 => bytes
169 .chunks_exact(8)
170 .map(|c| i64::from_le_bytes(c.try_into().unwrap()) as f32)
171 .collect(),
172 safetensors::Dtype::I32 => bytes
173 .chunks_exact(4)
174 .map(|c| i32::from_le_bytes([c[0], c[1], c[2], c[3]]) as f32)
175 .collect(),
176 safetensors::Dtype::C64 => {
177 continue;
182 }
183 other => anyhow::bail!("unsupported dtype: {other:?}"),
184 };
185 tensors.insert(name.to_string(), (f32_data, shape));
186 }
187
188 Ok(Self { tensors })
189 }
190
191 pub fn take(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)> {
193 self.tensors
194 .remove(key)
195 .ok_or_else(|| anyhow::anyhow!("weight not found: {key}"))
196 }
197
198 pub fn take_transposed(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)> {
200 let (data, shape) = self.take(key)?;
201 if shape.len() != 2 {
202 anyhow::bail!("transpose requires 2D, got {shape:?}");
203 }
204 let (rows, cols) = (shape[0], shape[1]);
205 let mut transposed = vec![0f32; data.len()];
206 for i in 0..rows {
207 for j in 0..cols {
208 transposed[j * rows + i] = data[i * cols + j];
209 }
210 }
211 Ok((transposed, vec![cols, rows]))
212 }
213
214 pub fn has(&self, key: &str) -> bool {
216 self.tensors.contains_key(key)
217 }
218
219 pub fn keys(&self) -> impl Iterator<Item = &str> {
221 self.tensors.keys().map(|s| s.as_str())
222 }
223
224 pub fn len(&self) -> usize {
226 self.tensors.len()
227 }
228 pub fn is_empty(&self) -> bool {
229 self.tensors.is_empty()
230 }
231
232 pub fn from_tensors(tensors: HashMap<String, (Vec<f32>, Vec<usize>)>) -> Self {
234 Self { tensors }
235 }
236
237 pub fn snapshot_from_path(path: &str) -> Result<F32WeightSnapshot> {
239 let mut wm = Self::from_file(path)?;
240 let keys: Vec<String> = wm.keys().map(|s| s.to_string()).collect();
241 let mut out = HashMap::with_capacity(keys.len());
242 for k in keys {
243 out.insert(k.clone(), wm.take(&k)?);
244 }
245 Ok(out)
246 }
247
248 pub fn from_safetensors_dir_selected(dir: &Path, want: &HashSet<String>) -> Result<Self> {
250 crate::safetensors_checkpoint::SafetensorsCheckpoint::open(dir)?.load_selected(want)
251 }
252
253 pub fn from_safetensors_dir(dir: &Path) -> Result<Self> {
255 let mut merged = HashMap::new();
256 let mut any = false;
257 for entry in std::fs::read_dir(dir).with_context(|| format!("read_dir {dir:?}"))? {
258 let entry = entry?;
259 let path = entry.path();
260 if path.extension().and_then(|s| s.to_str()) != Some("safetensors") {
261 continue;
262 }
263 let part = Self::from_file(
264 path.to_str()
265 .ok_or_else(|| anyhow::anyhow!("non-utf8 path {:?}", path))?,
266 )?;
267 for (k, v) in part.tensors {
268 merged.insert(k, v);
269 }
270 any = true;
271 }
272 if !any {
273 anyhow::bail!("no .safetensors files in {dir:?}");
274 }
275 Ok(Self { tensors: merged })
276 }
277
278 pub fn remap_keys<F>(&mut self, mut f: F)
280 where
281 F: FnMut(String) -> String,
282 {
283 let keys: Vec<String> = self.tensors.keys().cloned().collect();
284 for old in keys {
285 if let Some(v) = self.tensors.remove(&old) {
286 let new = f(old);
287 self.tensors.insert(new, v);
288 }
289 }
290 }
291
292 pub fn get(&self, key: &str) -> Option<(&[f32], &[usize])> {
294 self.tensors
295 .get(key)
296 .map(|(d, s)| (d.as_slice(), s.as_slice()))
297 }
298
299 pub fn merge_add_weight(&mut self, key: &str, delta: &[f32]) -> Result<()> {
301 let entry = self
302 .tensors
303 .get_mut(key)
304 .with_context(|| format!("merge_add_weight: missing {key}"))?;
305 let (data, shape) = entry;
306 ensure!(
307 shape.len() == 2,
308 "merge_add_weight {key}: expected rank-2, got {shape:?}"
309 );
310 ensure!(
311 data.len() == delta.len(),
312 "merge_add_weight {key}: len {} != delta {}",
313 data.len(),
314 delta.len()
315 );
316 for (d, s) in data.iter_mut().zip(delta.iter()) {
317 *d += s;
318 }
319 Ok(())
320 }
321}
322
/// Decodes `bytes` as consecutive little-endian `f32` values.
///
/// A single code path is used for every input so the result never depends on
/// how the slice happens to be aligned or on the host's byte order. On
/// little-endian targets each conversion is a plain 4-byte load.
///
/// `bytes.len()` is expected to be a multiple of 4. This is checked with a
/// debug assertion; in release builds a trailing partial value is ignored.
pub(crate) fn bytes_to_f32_vec(bytes: &[u8]) -> Vec<f32> {
    let words = bytes.chunks_exact(4);
    debug_assert!(
        words.remainder().is_empty(),
        "f32 byte slice length must be multiple of 4 (got {})",
        bytes.len()
    );
    words
        .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
        .collect()
}
344
#[cfg(test)]
mod tests {
    use super::*;

    /// A row-major 2x3 matrix comes back as its row-major 3x2 transpose.
    #[test]
    fn transpose_2d() {
        let mut tensors = HashMap::new();
        tensors.insert(
            "w".to_string(),
            (vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]),
        );
        let mut wm = WeightMap::from_tensors(tensors);

        let (data, shape) = wm.take_transposed("w").unwrap();

        assert_eq!(shape, vec![3, 2]);
        assert_eq!(data, vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }
}