1use anyhow::{Context, Result, bail, ensure};
19use std::collections::{HashMap, HashSet};
20use std::path::Path;
21
22use crate::gguf_support::{
23 gguf_architecture_from_path, gguf_safetensors_only_hint, resolve_weights_file,
24};
25use crate::weight_loader::WeightLoader;
26use crate::weight_registry::{LoadWeightsOptions, load_weight_map_resolved};
27use rlx_ir::quant::QuantScheme;
28
/// A packed tensor payload as a `(bytes, scheme, shape)` tuple — the same field order that
/// `WeightLoader::take_packed` yields inside `WeightMap::drain_loader`. The bytes are presumably
/// encoded according to `scheme` (TODO confirm against the loader implementations).
pub type PackedWeightTensor = (Vec<u8>, QuantScheme, Vec<usize>);
/// A packed tensor together with its name: `(name, bytes, scheme, shape)`, as collected by
/// `WeightMap::drain_loader`.
pub type NamedPackedWeight = (String, Vec<u8>, QuantScheme, Vec<usize>);
/// Owned name → `(f32 data, shape)` map, as returned by `WeightMap::snapshot_from_path`.
pub type F32WeightSnapshot = HashMap<String, (Vec<f32>, Vec<usize>)>;
35
/// How `WeightMap::drain_loader` reacts to keys the loader still reports after every
/// initially-listed key has been taken.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum WeightDrainPolicy {
    /// Ignore leftover keys (the default).
    #[default]
    AllF32,
    /// Print the number of leftover keys and up to 8 of their names to stderr, then continue.
    AllF32WarnUnused,
    /// Return an error listing the leftover count and up to 5 of their names.
    AllF32StrictUnused,
}
46
/// In-memory collection of tensors decoded to f32, keyed by tensor name.
///
/// `take` removes entries, so the map shrinks as weights are consumed.
pub struct WeightMap {
    // tensor name -> (flat f32 values, shape); 2-D data is indexed row-major by `take_transposed`
    tensors: HashMap<String, (Vec<f32>, Vec<usize>)>,
}
51
52impl WeightMap {
53 pub fn from_weight_loader(loader: &mut dyn WeightLoader) -> Result<Self> {
55 Self::drain_loader(loader, WeightDrainPolicy::AllF32).map(|(m, _)| m)
56 }
57
58 pub fn from_weight_loader_dequant_all(loader: &mut dyn WeightLoader) -> Result<Self> {
65 let keys = loader.remaining_keys();
66 let mut tensors = HashMap::with_capacity(keys.len());
67 for key in keys {
68 let (data, shape) = loader.take(&key)?;
69 tensors.insert(key, (data, shape));
70 }
71 Ok(Self { tensors })
72 }
73
74 pub fn drain_loader(
76 loader: &mut dyn WeightLoader,
77 policy: WeightDrainPolicy,
78 ) -> Result<(Self, Vec<NamedPackedWeight>)> {
79 let keys = loader.remaining_keys();
80 let mut tensors = HashMap::with_capacity(keys.len());
81 let mut packed = Vec::new();
82 for key in keys {
83 if let Some((bytes, scheme, shape)) = loader.take_packed(&key)? {
84 packed.push((key, bytes, scheme, shape));
85 continue;
86 }
87 let (data, shape) = loader.take(&key)?;
88 tensors.insert(key, (data, shape));
89 }
90 let left = loader.remaining_keys();
91 match policy {
92 WeightDrainPolicy::AllF32 => {}
93 WeightDrainPolicy::AllF32WarnUnused if !left.is_empty() => {
94 eprintln!(
95 "[rlx-core] weight drain: {} unused tensors (format={})",
96 left.len(),
97 loader.format_id()
98 );
99 for k in left.iter().take(8) {
100 eprintln!(" unused: {k}");
101 }
102 if left.len() > 8 {
103 eprintln!(" … and {} more", left.len() - 8);
104 }
105 }
106 WeightDrainPolicy::AllF32StrictUnused if !left.is_empty() => {
107 bail!(
108 "weight drain left {} unused tensors (format={}): {:?}",
109 left.len(),
110 loader.format_id(),
111 &left[..left.len().min(5)]
112 );
113 }
114 _ => {}
115 }
116 Ok((Self { tensors }, packed))
117 }
118
119 pub fn from_resolved_path(path: &Path) -> Result<Self> {
121 let file = resolve_weights_file(path)?;
122 Self::from_resolved_file(&file)
123 }
124
125 pub fn from_resolved_safetensors_only(path: &Path, runner: &str) -> Result<Self> {
127 let file = resolve_weights_file(path)?;
128 if file.extension().and_then(|s| s.to_str()) == Some("gguf") {
129 let arch = gguf_architecture_from_path(&file)?;
130 bail!(gguf_safetensors_only_hint(runner, &file, &arch));
131 }
132 Self::from_resolved_file(&file)
133 }
134
135 fn from_resolved_file(file: &Path) -> Result<Self> {
136 load_weight_map_resolved(file, LoadWeightsOptions::map()).map(|(_, m)| m)
137 }
138
139 pub fn from_file(path: &str) -> Result<Self> {
141 Self::from_file_excluding(path, &HashSet::new())
142 }
143
144 pub fn from_file_excluding(path: &str, exclude: &HashSet<String>) -> Result<Self> {
147 let data = std::fs::read(path).with_context(|| format!("reading {path}"))?;
148 let st =
149 safetensors::SafeTensors::deserialize(&data).with_context(|| "parsing safetensors")?;
150
151 let mut tensors = HashMap::new();
152 for (name, view) in st.tensors() {
153 if exclude.contains(name.as_str()) {
154 continue;
155 }
156 let shape: Vec<usize> = view.shape().to_vec();
157 let bytes = view.data();
158 let f32_data = match view.dtype() {
159 safetensors::Dtype::F32 => bytemuck_cast_f32(bytes),
160 safetensors::Dtype::F16 => bytes
161 .chunks_exact(2)
162 .map(|c| half::f16::from_le_bytes([c[0], c[1]]).to_f32())
163 .collect(),
164 safetensors::Dtype::BF16 => bytes
165 .chunks_exact(2)
166 .map(|c| half::bf16::from_le_bytes([c[0], c[1]]).to_f32())
167 .collect(),
168 safetensors::Dtype::I64 => bytes
169 .chunks_exact(8)
170 .map(|c| i64::from_le_bytes(c.try_into().unwrap()) as f32)
171 .collect(),
172 safetensors::Dtype::I32 => bytes
173 .chunks_exact(4)
174 .map(|c| i32::from_le_bytes([c[0], c[1], c[2], c[3]]) as f32)
175 .collect(),
176 safetensors::Dtype::C64 => {
177 continue;
182 }
183 other => anyhow::bail!("unsupported dtype: {other:?}"),
184 };
185 tensors.insert(name.to_string(), (f32_data, shape));
186 }
187
188 Ok(Self { tensors })
189 }
190
191 pub fn take(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)> {
193 self.tensors
194 .remove(key)
195 .ok_or_else(|| anyhow::anyhow!("weight not found: {key}"))
196 }
197
198 pub fn take_transposed(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)> {
200 let (data, shape) = self.take(key)?;
201 if shape.len() != 2 {
202 anyhow::bail!("transpose requires 2D, got {shape:?}");
203 }
204 let (rows, cols) = (shape[0], shape[1]);
205 let mut transposed = vec![0f32; data.len()];
206 for i in 0..rows {
207 for j in 0..cols {
208 transposed[j * rows + i] = data[i * cols + j];
209 }
210 }
211 Ok((transposed, vec![cols, rows]))
212 }
213
214 pub fn has(&self, key: &str) -> bool {
216 self.tensors.contains_key(key)
217 }
218
219 pub fn keys(&self) -> impl Iterator<Item = &str> {
221 self.tensors.keys().map(|s| s.as_str())
222 }
223
224 pub fn len(&self) -> usize {
226 self.tensors.len()
227 }
228 pub fn is_empty(&self) -> bool {
229 self.tensors.is_empty()
230 }
231
232 pub fn from_tensors(tensors: HashMap<String, (Vec<f32>, Vec<usize>)>) -> Self {
234 Self { tensors }
235 }
236
237 pub fn snapshot_from_path(path: &str) -> Result<F32WeightSnapshot> {
239 let mut wm = Self::from_file(path)?;
240 let keys: Vec<String> = wm.keys().map(|s| s.to_string()).collect();
241 let mut out = HashMap::with_capacity(keys.len());
242 for k in keys {
243 out.insert(k.clone(), wm.take(&k)?);
244 }
245 Ok(out)
246 }
247
248 fn tensor_bytes_to_f32(
249 name: &str,
250 view: safetensors::tensor::TensorView<'_>,
251 ) -> Result<Vec<f32>> {
252 let bytes = view.data();
253 Ok(match view.dtype() {
254 safetensors::Dtype::F32 => bytemuck_cast_f32(bytes),
255 safetensors::Dtype::F16 => bytes
256 .chunks_exact(2)
257 .map(|c| half::f16::from_le_bytes([c[0], c[1]]).to_f32())
258 .collect(),
259 safetensors::Dtype::BF16 => bytes
260 .chunks_exact(2)
261 .map(|c| half::bf16::from_le_bytes([c[0], c[1]]).to_f32())
262 .collect(),
263 safetensors::Dtype::I64 => bytes
264 .chunks_exact(8)
265 .map(|c| i64::from_le_bytes(c.try_into().unwrap()) as f32)
266 .collect(),
267 safetensors::Dtype::I32 => bytes
268 .chunks_exact(4)
269 .map(|c| i32::from_le_bytes([c[0], c[1], c[2], c[3]]) as f32)
270 .collect(),
271 safetensors::Dtype::C64 => return Ok(vec![]),
272 other => anyhow::bail!("{name}: unsupported dtype {other:?}"),
273 })
274 }
275
276 fn ingest_selected_from_bytes(
277 data: &[u8],
278 want: &HashSet<String>,
279 tensors: &mut HashMap<String, (Vec<f32>, Vec<usize>)>,
280 ) -> Result<()> {
281 let st = safetensors::SafeTensors::deserialize(data).context("parsing safetensors")?;
282 for (name, view) in st.tensors() {
283 if !want.contains(name.as_str()) {
284 continue;
285 }
286 let shape: Vec<usize> = view.shape().to_vec();
287 let f32_data = Self::tensor_bytes_to_f32(name.as_str(), view)?;
288 if f32_data.is_empty() {
289 continue;
290 }
291 tensors.insert(name.to_string(), (f32_data, shape));
292 }
293 Ok(())
294 }
295
296 pub fn from_safetensors_dir_selected(dir: &Path, want: &HashSet<String>) -> Result<Self> {
298 if want.is_empty() {
299 anyhow::bail!("from_safetensors_dir_selected: empty key set");
300 }
301 let index_path = dir.join("model.safetensors.index.json");
302 let mut tensors = HashMap::new();
303 if index_path.is_file() {
304 let index: serde_json::Value = serde_json::from_slice(&std::fs::read(&index_path)?)
305 .context("weight index json")?;
306 let weight_map = index
307 .get("weight_map")
308 .and_then(|m| m.as_object())
309 .context("weight_map in index")?;
310 let mut shard_files: HashSet<String> = HashSet::new();
311 for key in want {
312 if let Some(shard) = weight_map.get(key).and_then(|v| v.as_str()) {
313 shard_files.insert(shard.to_string());
314 }
315 }
316 for shard in shard_files {
317 let path = dir.join(&shard);
318 let data = std::fs::read(&path).with_context(|| format!("reading {path:?}"))?;
319 Self::ingest_selected_from_bytes(&data, want, &mut tensors)?;
320 }
321 } else {
322 for entry in std::fs::read_dir(dir).with_context(|| format!("read_dir {dir:?}"))? {
323 let path = entry?.path();
324 if path.extension().and_then(|s| s.to_str()) != Some("safetensors") {
325 continue;
326 }
327 let data = std::fs::read(&path).with_context(|| format!("reading {path:?}"))?;
328 Self::ingest_selected_from_bytes(&data, want, &mut tensors)?;
329 }
330 }
331 if tensors.is_empty() {
332 anyhow::bail!("no requested tensors found under {dir:?}");
333 }
334 Ok(Self { tensors })
335 }
336
337 pub fn from_safetensors_dir(dir: &Path) -> Result<Self> {
339 let mut merged = HashMap::new();
340 let mut any = false;
341 for entry in std::fs::read_dir(dir).with_context(|| format!("read_dir {dir:?}"))? {
342 let entry = entry?;
343 let path = entry.path();
344 if path.extension().and_then(|s| s.to_str()) != Some("safetensors") {
345 continue;
346 }
347 let part = Self::from_file(
348 path.to_str()
349 .ok_or_else(|| anyhow::anyhow!("non-utf8 path {:?}", path))?,
350 )?;
351 for (k, v) in part.tensors {
352 merged.insert(k, v);
353 }
354 any = true;
355 }
356 if !any {
357 anyhow::bail!("no .safetensors files in {dir:?}");
358 }
359 Ok(Self { tensors: merged })
360 }
361
362 pub fn remap_keys<F>(&mut self, mut f: F)
364 where
365 F: FnMut(String) -> String,
366 {
367 let keys: Vec<String> = self.tensors.keys().cloned().collect();
368 for old in keys {
369 if let Some(v) = self.tensors.remove(&old) {
370 let new = f(old);
371 self.tensors.insert(new, v);
372 }
373 }
374 }
375
376 pub fn get(&self, key: &str) -> Option<(&[f32], &[usize])> {
378 self.tensors
379 .get(key)
380 .map(|(d, s)| (d.as_slice(), s.as_slice()))
381 }
382
383 pub fn merge_add_weight(&mut self, key: &str, delta: &[f32]) -> Result<()> {
385 let entry = self
386 .tensors
387 .get_mut(key)
388 .with_context(|| format!("merge_add_weight: missing {key}"))?;
389 let (data, shape) = entry;
390 ensure!(
391 shape.len() == 2,
392 "merge_add_weight {key}: expected rank-2, got {shape:?}"
393 );
394 ensure!(
395 data.len() == delta.len(),
396 "merge_add_weight {key}: len {} != delta {}",
397 data.len(),
398 delta.len()
399 );
400 for (d, s) in data.iter_mut().zip(delta.iter()) {
401 *d += s;
402 }
403 Ok(())
404 }
405}
406
407fn bytemuck_cast_f32(bytes: &[u8]) -> Vec<f32> {
413 debug_assert!(
414 bytes.len().is_multiple_of(4),
415 "f32 byte slice length must be multiple of 4 (got {})",
416 bytes.len()
417 );
418 if (bytes.as_ptr() as usize).is_multiple_of(std::mem::align_of::<f32>()) {
419 let f32s: &[f32] = bytemuck::cast_slice(bytes);
420 f32s.to_vec()
421 } else {
422 bytes
423 .chunks_exact(4)
424 .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
425 .collect()
426 }
427}
428
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `WeightMap` from `(name, data, shape)` triples.
    fn map_of(entries: &[(&str, Vec<f32>, Vec<usize>)]) -> WeightMap {
        WeightMap::from_tensors(
            entries
                .iter()
                .map(|(name, data, shape)| (name.to_string(), (data.clone(), shape.clone())))
                .collect(),
        )
    }

    #[test]
    fn transpose_2d() {
        let mut wm = map_of(&[("w", vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3])]);
        let (data, shape) = wm.take_transposed("w").unwrap();
        assert_eq!(shape, vec![3, 2]);
        assert_eq!(data, vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }

    #[test]
    fn transpose_rejects_non_2d() {
        let mut wm = map_of(&[("v", vec![1.0, 2.0, 3.0], vec![3])]);
        assert!(wm.take_transposed("v").is_err());
    }

    #[test]
    fn take_removes_entry() {
        let mut wm = map_of(&[("a", vec![1.0], vec![1])]);
        assert!(wm.has("a"));
        assert_eq!(wm.len(), 1);
        assert_eq!(wm.take("a").unwrap(), (vec![1.0], vec![1]));
        assert!(!wm.has("a"));
        assert!(wm.is_empty());
        assert!(wm.take("a").is_err());
    }

    #[test]
    fn remap_keys_renames_entries() {
        let mut wm = map_of(&[
            ("model.a", vec![1.0], vec![1]),
            ("model.b", vec![2.0], vec![1]),
        ]);
        wm.remap_keys(|k| k.trim_start_matches("model.").to_string());
        let mut keys: Vec<&str> = wm.keys().collect();
        keys.sort_unstable();
        assert_eq!(keys, ["a", "b"]);
        assert_eq!(wm.get("b"), Some((&[2.0f32][..], &[1usize][..])));
    }

    #[test]
    fn merge_add_weight_adds_in_place_and_validates() {
        let mut wm = map_of(&[("w", vec![1.0, 2.0, 3.0, 4.0], vec![2, 2])]);
        wm.merge_add_weight("w", &[0.5, 0.5, 0.5, 0.5]).unwrap();
        assert_eq!(wm.get("w").unwrap().0, &[1.5, 2.5, 3.5, 4.5]);
        assert!(wm.merge_add_weight("w", &[1.0]).is_err());
        assert!(wm.merge_add_weight("missing", &[]).is_err());
    }

    #[test]
    fn f32_bytes_decode_little_endian() {
        let values = [1.0f32, -2.5, 3.25];
        let bytes: Vec<u8> = values.iter().flat_map(|v| v.to_le_bytes()).collect();
        assert_eq!(bytemuck_cast_f32(&bytes), values);
        // Offsetting by one byte most likely misaligns the slice, exercising the fallback path.
        let mut shifted = vec![0u8];
        shifted.extend_from_slice(&bytes);
        assert_eq!(bytemuck_cast_f32(&shifted[1..]), values);
    }
}