Skip to main content

rlx_models_core/
weight_map.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Safetensors weight loading — standalone, no framework dependency.
17
18use anyhow::{Context, Result, bail, ensure};
19use std::collections::{HashMap, HashSet};
20use std::path::Path;
21
22use crate::gguf_support::{
23    gguf_architecture_from_path, gguf_safetensors_only_hint, resolve_weights_file,
24};
25use crate::weight_loader::WeightLoader;
26use crate::weight_registry::{LoadWeightsOptions, load_weight_map_resolved};
27use rlx_ir::quant::QuantScheme;
28
29/// Packed GGUF weight bytes + scheme + logical shape.
30pub type PackedWeightTensor = (Vec<u8>, QuantScheme, Vec<usize>);
31/// Named packed tensor (sidecar list from [`WeightMap::drain_loader`]).
32pub type NamedPackedWeight = (String, Vec<u8>, QuantScheme, Vec<usize>);
33/// F32 tensor snapshot (`name → (data, shape)`).
34pub type F32WeightSnapshot = HashMap<String, (Vec<f32>, Vec<usize>)>;
35
/// Policy applied by [`WeightMap::drain_loader`] to tensors the loader still
/// reports once the drain loop has finished.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum WeightDrainPolicy {
    /// Default policy: leftover tensors are ignored silently.
    #[default]
    AllF32,
    /// Print a warning to stderr when tensors remain after the drain.
    AllF32WarnUnused,
    /// Return an error when any tensor was not taken.
    AllF32StrictUnused,
}
46
/// In-memory weight store keyed by tensor name; each entry holds the tensor's
/// f32 data and its shape.
pub struct WeightMap {
    tensors: HashMap<String, (Vec<f32>, Vec<usize>)>,
}
51
52impl WeightMap {
53    /// Drain every tensor from any [`WeightLoader`] (safetensors or GGUF).
54    pub fn from_weight_loader(loader: &mut dyn WeightLoader) -> Result<Self> {
55        Self::drain_loader(loader, WeightDrainPolicy::AllF32).map(|(m, _)| m)
56    }
57
58    /// Force-dequantize every tensor (including K-quants) into F32 and
59    /// drop it in the map. Use when a family runner doesn't have a
60    /// packed-matmul lowering yet but still wants to load GGUFs whose
61    /// trunk weights are K-quant. Trades memory (4× larger than the
62    /// packed bytes) for correctness — every tensor goes through
63    /// `WeightLoader::take(...)` which dequantizes on the fly.
64    pub fn from_weight_loader_dequant_all(loader: &mut dyn WeightLoader) -> Result<Self> {
65        let keys = loader.remaining_keys();
66        let mut tensors = HashMap::with_capacity(keys.len());
67        for key in keys {
68            let (data, shape) = loader.take(&key)?;
69            tensors.insert(key, (data, shape));
70        }
71        Ok(Self { tensors })
72    }
73
74    /// Drain with policy; returns packed K-quants separately when the loader supports `take_packed`.
75    pub fn drain_loader(
76        loader: &mut dyn WeightLoader,
77        policy: WeightDrainPolicy,
78    ) -> Result<(Self, Vec<NamedPackedWeight>)> {
79        let keys = loader.remaining_keys();
80        let mut tensors = HashMap::with_capacity(keys.len());
81        let mut packed = Vec::new();
82        for key in keys {
83            if let Some((bytes, scheme, shape)) = loader.take_packed(&key)? {
84                packed.push((key, bytes, scheme, shape));
85                continue;
86            }
87            let (data, shape) = loader.take(&key)?;
88            tensors.insert(key, (data, shape));
89        }
90        let left = loader.remaining_keys();
91        match policy {
92            WeightDrainPolicy::AllF32 => {}
93            WeightDrainPolicy::AllF32WarnUnused if !left.is_empty() => {
94                eprintln!(
95                    "[rlx-core] weight drain: {} unused tensors (format={})",
96                    left.len(),
97                    loader.format_id()
98                );
99                for k in left.iter().take(8) {
100                    eprintln!("  unused: {k}");
101                }
102                if left.len() > 8 {
103                    eprintln!("  … and {} more", left.len() - 8);
104                }
105            }
106            WeightDrainPolicy::AllF32StrictUnused if !left.is_empty() => {
107                bail!(
108                    "weight drain left {} unused tensors (format={}): {:?}",
109                    left.len(),
110                    loader.format_id(),
111                    &left[..left.len().min(5)]
112                );
113            }
114            _ => {}
115        }
116        Ok((Self { tensors }, packed))
117    }
118
119    /// Resolve a file or weights directory, then load (safetensors or GGUF).
120    pub fn from_resolved_path(path: &Path) -> Result<Self> {
121        let file = resolve_weights_file(path)?;
122        Self::from_resolved_file(&file)
123    }
124
125    /// Resolve path; reject `.gguf` with a hint naming the right runner.
126    pub fn from_resolved_safetensors_only(path: &Path, runner: &str) -> Result<Self> {
127        let file = resolve_weights_file(path)?;
128        if file.extension().and_then(|s| s.to_str()) == Some("gguf") {
129            let arch = gguf_architecture_from_path(&file)?;
130            bail!(gguf_safetensors_only_hint(runner, &file, &arch));
131        }
132        Self::from_resolved_file(&file)
133    }
134
135    fn from_resolved_file(file: &Path) -> Result<Self> {
136        load_weight_map_resolved(file, LoadWeightsOptions::map()).map(|(_, m)| m)
137    }
138
139    /// Load weights from a safetensors file. Auto-converts bf16/f16 to f32.
140    pub fn from_file(path: &str) -> Result<Self> {
141        Self::from_file_excluding(path, &HashSet::new())
142    }
143
144    /// Load weights, skipping tensor names present in `exclude` (saves RAM when
145    /// bf16/NVFP4 linears are loaded separately for GPU upload).
146    pub fn from_file_excluding(path: &str, exclude: &HashSet<String>) -> Result<Self> {
147        let data = std::fs::read(path).with_context(|| format!("reading {path}"))?;
148        let st =
149            safetensors::SafeTensors::deserialize(&data).with_context(|| "parsing safetensors")?;
150
151        let mut tensors = HashMap::new();
152        for (name, view) in st.tensors() {
153            if exclude.contains(name.as_str()) {
154                continue;
155            }
156            let shape: Vec<usize> = view.shape().to_vec();
157            let bytes = view.data();
158            let f32_data = match view.dtype() {
159                safetensors::Dtype::F32 => bytemuck_cast_f32(bytes),
160                safetensors::Dtype::F16 => bytes
161                    .chunks_exact(2)
162                    .map(|c| half::f16::from_le_bytes([c[0], c[1]]).to_f32())
163                    .collect(),
164                safetensors::Dtype::BF16 => bytes
165                    .chunks_exact(2)
166                    .map(|c| half::bf16::from_le_bytes([c[0], c[1]]).to_f32())
167                    .collect(),
168                safetensors::Dtype::I64 => bytes
169                    .chunks_exact(8)
170                    .map(|c| i64::from_le_bytes(c.try_into().unwrap()) as f32)
171                    .collect(),
172                safetensors::Dtype::I32 => bytes
173                    .chunks_exact(4)
174                    .map(|c| i32::from_le_bytes([c[0], c[1], c[2], c[3]]) as f32)
175                    .collect(),
176                safetensors::Dtype::C64 => {
177                    // Some checkpoints (SAM3) include complex RoPE caches
178                    // such as `freqs_cis`. Native code regenerates/handles
179                    // those separately; keep loading usable for the real
180                    // float weights instead of rejecting the entire file.
181                    continue;
182                }
183                other => anyhow::bail!("unsupported dtype: {other:?}"),
184            };
185            tensors.insert(name.to_string(), (f32_data, shape));
186        }
187
188        Ok(Self { tensors })
189    }
190
191    /// Take a tensor by name (removes from map). Returns (data, shape).
192    pub fn take(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)> {
193        self.tensors
194            .remove(key)
195            .ok_or_else(|| anyhow::anyhow!("weight not found: {key}"))
196    }
197
198    /// Take and transpose a 2D weight: [out, in] → [in, out] for row-major matmul.
199    pub fn take_transposed(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)> {
200        let (data, shape) = self.take(key)?;
201        if shape.len() != 2 {
202            anyhow::bail!("transpose requires 2D, got {shape:?}");
203        }
204        let (rows, cols) = (shape[0], shape[1]);
205        let mut transposed = vec![0f32; data.len()];
206        for i in 0..rows {
207            for j in 0..cols {
208                transposed[j * rows + i] = data[i * cols + j];
209            }
210        }
211        Ok((transposed, vec![cols, rows]))
212    }
213
214    /// Check if a key exists.
215    pub fn has(&self, key: &str) -> bool {
216        self.tensors.contains_key(key)
217    }
218
219    /// List all keys.
220    pub fn keys(&self) -> impl Iterator<Item = &str> {
221        self.tensors.keys().map(|s| s.as_str())
222    }
223
224    /// Number of tensors remaining.
225    pub fn len(&self) -> usize {
226        self.tensors.len()
227    }
228    pub fn is_empty(&self) -> bool {
229        self.tensors.is_empty()
230    }
231
232    /// Create from pre-built HashMap (for testing without safetensors files).
233    pub fn from_tensors(tensors: HashMap<String, (Vec<f32>, Vec<usize>)>) -> Self {
234        Self { tensors }
235    }
236
237    /// Drain all tensors into a snapshot map (for runners that rebuild graphs per shape).
238    pub fn snapshot_from_path(path: &str) -> Result<F32WeightSnapshot> {
239        let mut wm = Self::from_file(path)?;
240        let keys: Vec<String> = wm.keys().map(|s| s.to_string()).collect();
241        let mut out = HashMap::with_capacity(keys.len());
242        for k in keys {
243            out.insert(k.clone(), wm.take(&k)?);
244        }
245        Ok(out)
246    }
247
248    fn tensor_bytes_to_f32(
249        name: &str,
250        view: safetensors::tensor::TensorView<'_>,
251    ) -> Result<Vec<f32>> {
252        let bytes = view.data();
253        Ok(match view.dtype() {
254            safetensors::Dtype::F32 => bytemuck_cast_f32(bytes),
255            safetensors::Dtype::F16 => bytes
256                .chunks_exact(2)
257                .map(|c| half::f16::from_le_bytes([c[0], c[1]]).to_f32())
258                .collect(),
259            safetensors::Dtype::BF16 => bytes
260                .chunks_exact(2)
261                .map(|c| half::bf16::from_le_bytes([c[0], c[1]]).to_f32())
262                .collect(),
263            safetensors::Dtype::I64 => bytes
264                .chunks_exact(8)
265                .map(|c| i64::from_le_bytes(c.try_into().unwrap()) as f32)
266                .collect(),
267            safetensors::Dtype::I32 => bytes
268                .chunks_exact(4)
269                .map(|c| i32::from_le_bytes([c[0], c[1], c[2], c[3]]) as f32)
270                .collect(),
271            safetensors::Dtype::C64 => return Ok(vec![]),
272            other => anyhow::bail!("{name}: unsupported dtype {other:?}"),
273        })
274    }
275
276    fn ingest_selected_from_bytes(
277        data: &[u8],
278        want: &HashSet<String>,
279        tensors: &mut HashMap<String, (Vec<f32>, Vec<usize>)>,
280    ) -> Result<()> {
281        let st = safetensors::SafeTensors::deserialize(data).context("parsing safetensors")?;
282        for (name, view) in st.tensors() {
283            if !want.contains(name.as_str()) {
284                continue;
285            }
286            let shape: Vec<usize> = view.shape().to_vec();
287            let f32_data = Self::tensor_bytes_to_f32(name.as_str(), view)?;
288            if f32_data.is_empty() {
289                continue;
290            }
291            tensors.insert(name.to_string(), (f32_data, shape));
292        }
293        Ok(())
294    }
295
296    /// Load only tensors whose names appear in `want` (HF sharded checkpoints).
297    pub fn from_safetensors_dir_selected(dir: &Path, want: &HashSet<String>) -> Result<Self> {
298        if want.is_empty() {
299            anyhow::bail!("from_safetensors_dir_selected: empty key set");
300        }
301        let index_path = dir.join("model.safetensors.index.json");
302        let mut tensors = HashMap::new();
303        if index_path.is_file() {
304            let index: serde_json::Value = serde_json::from_slice(&std::fs::read(&index_path)?)
305                .context("weight index json")?;
306            let weight_map = index
307                .get("weight_map")
308                .and_then(|m| m.as_object())
309                .context("weight_map in index")?;
310            let mut shard_files: HashSet<String> = HashSet::new();
311            for key in want {
312                if let Some(shard) = weight_map.get(key).and_then(|v| v.as_str()) {
313                    shard_files.insert(shard.to_string());
314                }
315            }
316            for shard in shard_files {
317                let path = dir.join(&shard);
318                let data = std::fs::read(&path).with_context(|| format!("reading {path:?}"))?;
319                Self::ingest_selected_from_bytes(&data, want, &mut tensors)?;
320            }
321        } else {
322            for entry in std::fs::read_dir(dir).with_context(|| format!("read_dir {dir:?}"))? {
323                let path = entry?.path();
324                if path.extension().and_then(|s| s.to_str()) != Some("safetensors") {
325                    continue;
326                }
327                let data = std::fs::read(&path).with_context(|| format!("reading {path:?}"))?;
328                Self::ingest_selected_from_bytes(&data, want, &mut tensors)?;
329            }
330        }
331        if tensors.is_empty() {
332            anyhow::bail!("no requested tensors found under {dir:?}");
333        }
334        Ok(Self { tensors })
335    }
336
337    /// Load and merge every `*.safetensors` file in `dir` (e.g. HF `text_encoder/`).
338    pub fn from_safetensors_dir(dir: &Path) -> Result<Self> {
339        let mut merged = HashMap::new();
340        let mut any = false;
341        for entry in std::fs::read_dir(dir).with_context(|| format!("read_dir {dir:?}"))? {
342            let entry = entry?;
343            let path = entry.path();
344            if path.extension().and_then(|s| s.to_str()) != Some("safetensors") {
345                continue;
346            }
347            let part = Self::from_file(
348                path.to_str()
349                    .ok_or_else(|| anyhow::anyhow!("non-utf8 path {:?}", path))?,
350            )?;
351            for (k, v) in part.tensors {
352                merged.insert(k, v);
353            }
354            any = true;
355        }
356        if !any {
357            anyhow::bail!("no .safetensors files in {dir:?}");
358        }
359        Ok(Self { tensors: merged })
360    }
361
362    /// Rename keys in-place (e.g. strip `model.` HuggingFace prefix).
363    pub fn remap_keys<F>(&mut self, mut f: F)
364    where
365        F: FnMut(String) -> String,
366    {
367        let keys: Vec<String> = self.tensors.keys().cloned().collect();
368        for old in keys {
369            if let Some(v) = self.tensors.remove(&old) {
370                let new = f(old);
371                self.tensors.insert(new, v);
372            }
373        }
374    }
375
376    /// Borrow tensor data + shape without removing from the map.
377    pub fn get(&self, key: &str) -> Option<(&[f32], &[usize])> {
378        self.tensors
379            .get(key)
380            .map(|(d, s)| (d.as_slice(), s.as_slice()))
381    }
382
383    /// Element-wise add `delta` into an existing rank-2 weight (PyTorch `[out, in]` layout).
384    pub fn merge_add_weight(&mut self, key: &str, delta: &[f32]) -> Result<()> {
385        let entry = self
386            .tensors
387            .get_mut(key)
388            .with_context(|| format!("merge_add_weight: missing {key}"))?;
389        let (data, shape) = entry;
390        ensure!(
391            shape.len() == 2,
392            "merge_add_weight {key}: expected rank-2, got {shape:?}"
393        );
394        ensure!(
395            data.len() == delta.len(),
396            "merge_add_weight {key}: len {} != delta {}",
397            data.len(),
398            delta.len()
399        );
400        for (d, s) in data.iter_mut().zip(delta.iter()) {
401            *d += s;
402        }
403        Ok(())
404    }
405}
406
/// Decode a little-endian byte slice into a `Vec<f32>`.
///
/// Safetensors stores tensor data little-endian at arbitrary byte offsets, so
/// an f32 tensor need not start on a 4-byte boundary (SAM ViT-B is one such
/// file). Every 4-byte chunk is decoded with `f32::from_le_bytes`, which works
/// for any alignment and any host byte order. A reinterpreting slice cast
/// would require alignment, use native endianness, and panic on a length that
/// is not a multiple of 4.
///
/// A length that is not a multiple of 4 trips the debug assertion; in release
/// builds the trailing partial chunk is ignored.
fn bytemuck_cast_f32(bytes: &[u8]) -> Vec<f32> {
    let chunks = bytes.chunks_exact(4);
    debug_assert!(
        chunks.remainder().is_empty(),
        "f32 byte slice length must be multiple of 4 (got {})",
        bytes.len()
    );
    chunks
        .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
        .collect()
}
428
#[cfg(test)]
mod tests {
    use super::*;

    /// A 2×3 row-major matrix must come back as its 3×2 transpose.
    #[test]
    fn transpose_2d() {
        let source = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mut wm = WeightMap {
            tensors: HashMap::from([("w".to_string(), (source, vec![2, 3]))]),
        };

        let (data, shape) = wm.take_transposed("w").unwrap();

        // Original: [[1,2,3],[4,5,6]] → Transposed: [[1,4],[2,5],[3,6]]
        assert_eq!(shape, vec![3, 2]);
        assert_eq!(data, vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }
}
446}