1use anyhow::{Context, Result, bail};
7use memmap2::{Mmap, MmapOptions};
8use rlx_core::weight_map::WeightMap;
9use safetensors::SafeTensors;
10use std::fs::File;
11use std::path::{Path, PathBuf};
12use std::sync::Mutex;
13
14use super::paths::is_rten_checkpoint;
15
16pub struct SafetensorsFile {
18 path: PathBuf,
19 mmap: Mutex<Option<Mmap>>,
20}
21
22impl SafetensorsFile {
23 pub fn open(path: impl AsRef<Path>) -> Result<Self> {
24 Ok(Self {
25 path: path.as_ref().to_path_buf(),
26 mmap: Mutex::new(None),
27 })
28 }
29
30 fn with_mmap<R>(&self, f: impl FnOnce(&Mmap) -> Result<R>) -> Result<R> {
31 let mut guard = self
32 .mmap
33 .lock()
34 .map_err(|_| anyhow::anyhow!("safetensors mmap lock poisoned"))?;
35 if guard.is_none() {
36 let file = File::open(&self.path).with_context(|| format!("open {:?}", self.path))?;
37 *guard = Some(unsafe { MmapOptions::new().map(&file)? });
38 }
39 f(guard.as_ref().unwrap())
40 }
41
42 pub fn weight_map(&self) -> Result<WeightMap> {
44 self.with_mmap(load_safetensors_weights_from_mmap)
45 }
46}
47
48pub fn load_safetensors(path: &Path) -> Result<WeightMap> {
50 SafetensorsFile::open(path)?.weight_map()
51}
52
53pub(crate) fn load_safetensors_weights_from_mmap(mmap: &Mmap) -> Result<WeightMap> {
54 let mut wm = drain_safetensors_bytes(mmap)?;
55 strip_graph_scalars(&mut wm);
56 Ok(wm)
57}
58
59fn drain_safetensors_bytes(data: &[u8]) -> Result<WeightMap> {
60 let st = SafeTensors::deserialize(data).context("parse safetensors")?;
61 let mut tensors = std::collections::HashMap::new();
62 for (name, view) in st.tensors() {
63 let shape: Vec<usize> = view.shape().to_vec();
64 let bytes = view.data();
65 let f32_data = match view.dtype() {
66 safetensors::Dtype::F32 => {
67 if bytes.len() % 4 != 0 {
68 bail!("{name}: invalid f32 byte length");
69 }
70 bytes
71 .chunks_exact(4)
72 .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
73 .collect()
74 }
75 safetensors::Dtype::F16 => bytes
76 .chunks_exact(2)
77 .map(|c| ::half::f16::from_le_bytes([c[0], c[1]]).to_f32())
78 .collect(),
79 safetensors::Dtype::BF16 => bytes
80 .chunks_exact(2)
81 .map(|c| ::half::bf16::from_le_bytes([c[0], c[1]]).to_f32())
82 .collect(),
83 other => bail!("{name}: unsupported dtype {other:?}"),
84 };
85 tensors.insert(name.to_string(), (f32_data, shape));
86 }
87 Ok(WeightMap::from_tensors(tensors))
88}
89
90fn strip_graph_scalars(wm: &mut WeightMap) {
92 let remove: Vec<String> = wm
93 .keys()
94 .filter(|k| k.starts_with('/') || k.contains("Constant") || k.contains("Unsqueeze"))
95 .map(str::to_string)
96 .collect();
97 for k in remove {
98 let _ = wm.take(&k);
99 }
100}
101
102pub fn load_safetensors_weights(path: &Path) -> Result<WeightMap> {
104 if is_rten_checkpoint(path) {
105 bail!(
106 "RLX graph weights require .safetensors ({:?}); run `rlx-ocr-convert` on .rten checkpoints",
107 path
108 );
109 }
110 SafetensorsFile::open(path)?.weight_map()
111}
112
113pub fn load_rlx_weights(path: &Path) -> Result<WeightMap> {
115 load_safetensors_weights(path)
116}