Skip to main content

rlx_ocr/weights/
load.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.

//! Safetensors → [`WeightMap`] for native RLX graph build.
5
6use anyhow::{Context, Result, bail};
7use memmap2::{Mmap, MmapOptions};
8use rlx_core::weight_map::WeightMap;
9use safetensors::SafeTensors;
10use std::fs::File;
11use std::path::{Path, PathBuf};
12use std::sync::Mutex;
13
14use super::paths::is_rten_checkpoint;
15
16/// Mmap-backed safetensors file; reuse across per-width graph builds.
17pub struct SafetensorsFile {
18    path: PathBuf,
19    mmap: Mutex<Option<Mmap>>,
20}
21
22impl SafetensorsFile {
23    pub fn open(path: impl AsRef<Path>) -> Result<Self> {
24        Ok(Self {
25            path: path.as_ref().to_path_buf(),
26            mmap: Mutex::new(None),
27        })
28    }
29
30    fn with_mmap<R>(&self, f: impl FnOnce(&Mmap) -> Result<R>) -> Result<R> {
31        let mut guard = self
32            .mmap
33            .lock()
34            .map_err(|_| anyhow::anyhow!("safetensors mmap lock poisoned"))?;
35        if guard.is_none() {
36            let file = File::open(&self.path).with_context(|| format!("open {:?}", self.path))?;
37            *guard = Some(unsafe { MmapOptions::new().map(&file)? });
38        }
39        f(guard.as_ref().unwrap())
40    }
41
42    /// Fresh [`WeightMap`] for graph build (drains keys from a parse of the mmap).
43    pub fn weight_map(&self) -> Result<WeightMap> {
44        self.with_mmap(load_safetensors_weights_from_mmap)
45    }
46}
47
48/// Load a `.safetensors` file via mmap-backed read into f32 [`WeightMap`] tensors.
49pub fn load_safetensors(path: &Path) -> Result<WeightMap> {
50    SafetensorsFile::open(path)?.weight_map()
51}
52
53pub(crate) fn load_safetensors_weights_from_mmap(mmap: &Mmap) -> Result<WeightMap> {
54    let mut wm = drain_safetensors_bytes(mmap)?;
55    strip_graph_scalars(&mut wm);
56    Ok(wm)
57}
58
59fn drain_safetensors_bytes(data: &[u8]) -> Result<WeightMap> {
60    let st = SafeTensors::deserialize(data).context("parse safetensors")?;
61    let mut tensors = std::collections::HashMap::new();
62    for (name, view) in st.tensors() {
63        let shape: Vec<usize> = view.shape().to_vec();
64        let bytes = view.data();
65        let f32_data = match view.dtype() {
66            safetensors::Dtype::F32 => {
67                if bytes.len() % 4 != 0 {
68                    bail!("{name}: invalid f32 byte length");
69                }
70                bytes
71                    .chunks_exact(4)
72                    .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
73                    .collect()
74            }
75            safetensors::Dtype::F16 => bytes
76                .chunks_exact(2)
77                .map(|c| ::half::f16::from_le_bytes([c[0], c[1]]).to_f32())
78                .collect(),
79            safetensors::Dtype::BF16 => bytes
80                .chunks_exact(2)
81                .map(|c| ::half::bf16::from_le_bytes([c[0], c[1]]).to_f32())
82                .collect(),
83            other => bail!("{name}: unsupported dtype {other:?}"),
84        };
85        tensors.insert(name.to_string(), (f32_data, shape));
86    }
87    Ok(WeightMap::from_tensors(tensors))
88}
89
90/// Drop RTen ONNX scalar nodes (slice/pad helpers) not used by the native graph.
91fn strip_graph_scalars(wm: &mut WeightMap) {
92    let remove: Vec<String> = wm
93        .keys()
94        .filter(|k| k.starts_with('/') || k.contains("Constant") || k.contains("Unsqueeze"))
95        .map(str::to_string)
96        .collect();
97    for k in remove {
98        let _ = wm.take(&k);
99    }
100}
101
102/// Load weights for RLX graph build (safetensors only).
103pub fn load_safetensors_weights(path: &Path) -> Result<WeightMap> {
104    if is_rten_checkpoint(path) {
105        bail!(
106            "RLX graph weights require .safetensors ({:?}); run `rlx-ocr-convert` on .rten checkpoints",
107            path
108        );
109    }
110    SafetensorsFile::open(path)?.weight_map()
111}
112
113/// Alias for [`load_safetensors_weights`].
114pub fn load_rlx_weights(path: &Path) -> Result<WeightMap> {
115    load_safetensors_weights(path)
116}