Skip to main content

rlx_umap/
serialize.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Weight and model I/O.
17//!
18//! **Default:** [safetensors](https://huggingface.co/docs/safetensors) (`.safetensors`) with
19//! `rlx_umap.*` string metadata — see [`crate::model_io`].
20//!
21//! **Also:** GGUF F32 (`.gguf`, feature `io-gguf`), legacy `.ruama` v1–v4 (load only for old files).
22
23use std::io::{Read, Write};
24use std::path::{Path, PathBuf};
25
26use crate::config::UmapConfig;
27use crate::encoder::mlp::ModelSpec;
28use crate::utils::NormStats;
29use crate::weights::WeightStore;
30
/// Four-byte tag at the start of every legacy `.ruama` file; `read_header` rejects
/// files that do not begin with it.
const MAGIC: &[u8; 4] = b"RUMA";
/// Layout v1: tensor count followed by the tensors (weights only).
const VERSION_V1: u32 = 1;
/// Layout v2: four `u32` metadata words, then the tensor count and tensors.
const VERSION_V2: u32 = 2;
/// Layout v3: as v2, with the normalization `mean` and `std` slices stored
/// between the metadata and the tensor count.
const VERSION_V3: u32 = 3;
/// Layout v4: as v3, with a length-prefixed JSON config stored after the
/// normalization slices.
const VERSION_V4: u32 = 4;
36
/// Training layout metadata stored alongside weights.
///
/// In the legacy `.ruama` format each field is stored as a little-endian `u32`
/// (see `write_bundle` / `read_meta`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ModelMetadata {
    // Presumably the number of training samples — TODO confirm against the trainer.
    pub n_train: usize,
    // Input feature count; `load_bundle` requires the stored normalization
    // `mean`/`std` slices to have exactly this many entries.
    pub n_features: usize,
    // NOTE(review): likely positive-sample count per step — meaning not visible here.
    pub n_pos: usize,
    // NOTE(review): likely negative-sample count per step — meaning not visible here.
    pub n_neg: usize,
}
45
/// Everything needed to reconstruct a [`crate::fitted::FittedUmap`].
#[derive(Debug, Clone)]
pub struct LoadedModel {
    // Named tensors read from the file.
    pub weights: WeightStore,
    // Training layout metadata stored with the weights.
    pub meta: ModelMetadata,
    // Normalization statistics; when a legacy `.ruama` file carries none,
    // `load_bundle` substitutes mean 0.0 / std 1.0 for every feature.
    pub norm: NormStats,
    // Stored configuration; `None` for legacy `.ruama` layouts older than v4.
    pub config: Option<UmapConfig>,
}
54
/// View for writing a full fitted model (v4).
///
/// Borrows the weights, normalization statistics and config, so building a
/// bundle does not copy them; the metadata is `Copy` and held by value.
pub struct SaveBundle<'a> {
    // Tensors to persist.
    pub weights: &'a WeightStore,
    // Training layout metadata to persist.
    pub meta: ModelMetadata,
    // Normalization statistics to persist.
    pub norm: &'a NormStats,
    // Configuration to persist.
    pub config: &'a UmapConfig,
}
62
63/// Save weights only (`.safetensors` / `.gguf` by extension).
64pub fn save_weights(
65    w: &WeightStore,
66    spec: &ModelSpec,
67    path: impl AsRef<Path>,
68) -> std::io::Result<()> {
69    crate::model_io::save_weights(w, spec, path)
70}
71
72/// Load encoder weights (safetensors, GGUF, or legacy `.ruama`).
73pub fn load_weights(path: impl AsRef<Path>) -> std::io::Result<WeightStore> {
74    crate::model_io::load_weights(path)
75}
76
77/// Save weights to legacy `.ruama` v1 (backward compatibility).
78pub(crate) fn save_weights_ruama(w: &WeightStore, path: impl AsRef<Path>) -> std::io::Result<()> {
79    write_bundle(path, w, None, None, None)
80}
81
82/// Load weights from a legacy `.ruama` v1 file only.
83pub(crate) fn load_weights_ruama(path: impl AsRef<Path>) -> std::io::Result<WeightStore> {
84    let mut file = std::fs::File::open(path.as_ref())?;
85    let version = read_header(&mut file)?;
86    if version == VERSION_V1 {
87        let count = read_count(&mut file)?;
88        return read_tensors(&mut file, count);
89    }
90    drop(file);
91    load_bundle(path).map(|b| b.weights)
92}
93
94/// Save full model (weights + metadata + norm + config).
95pub fn save_model(bundle: SaveBundle<'_>, path: impl AsRef<Path>) -> std::io::Result<()> {
96    crate::model_io::save_model(bundle, path)
97}
98
99/// Load a full model (safetensors, GGUF, or legacy `.ruama`).
100pub fn load_model(path: impl AsRef<Path>) -> std::io::Result<LoadedModel> {
101    crate::model_io::load_model(path)
102}
103
104/// Legacy `.ruama` v2+ bundle (used by [`crate::model_io`]).
105pub(crate) fn load_legacy_ruama(path: impl AsRef<Path>) -> std::io::Result<LoadedModel> {
106    load_bundle(path)
107}
108
/// Writes `data` as a length-prefixed record: a little-endian `u32` byte count
/// followed by the bytes themselves.
///
/// # Errors
/// Returns `InvalidInput` if `data` is longer than `u32::MAX` bytes (the prefix
/// could not represent it and the record would be unreadable), or the
/// underlying I/O error if a write fails.
fn write_bytes(file: &mut std::fs::File, data: &[u8]) -> std::io::Result<()> {
    // A plain `as u32` cast would silently wrap for oversized payloads.
    let len = u32::try_from(data.len()).map_err(|_| {
        std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            format!("payload of {} bytes exceeds the u32 length prefix", data.len()),
        )
    })?;
    file.write_all(&len.to_le_bytes())?;
    file.write_all(data)
}
114
/// Reads one length-prefixed record: a little-endian `u32` byte count followed
/// by that many bytes.
///
/// The length prefix comes from the file and is not trusted: the buffer grows
/// as bytes actually arrive instead of being allocated up front, so a corrupt
/// prefix cannot force a multi-gigabyte allocation.
///
/// # Errors
/// Returns `UnexpectedEof` if the file ends before the prefix or the announced
/// payload is complete, or the underlying I/O error if a read fails.
fn read_bytes(file: &mut std::fs::File) -> std::io::Result<Vec<u8>> {
    // Upper bound on the speculative reservation made before any payload is read.
    const PREALLOC_CAP: u64 = 1 << 20;

    let mut len_buf = [0u8; 4];
    file.read_exact(&mut len_buf)?;
    let len = u64::from(u32::from_le_bytes(len_buf));

    let mut data = Vec::with_capacity(len.min(PREALLOC_CAP) as usize);
    let got = file.by_ref().take(len).read_to_end(&mut data)?;
    if got as u64 != len {
        return Err(std::io::Error::new(
            std::io::ErrorKind::UnexpectedEof,
            format!("record truncated: expected {len} bytes, found {got}"),
        ));
    }
    Ok(data)
}
123
/// Writes `data` as a little-endian `u32` element count followed by each value
/// as a little-endian `f64`.
///
/// The record is assembled in memory and written with a single call, rather
/// than one unbuffered write per value.
///
/// # Errors
/// Returns `InvalidInput` if `data` has more than `u32::MAX` elements, or the
/// underlying I/O error if the write fails.
fn write_f64_slice(file: &mut std::fs::File, data: &[f64]) -> std::io::Result<()> {
    // A plain `as u32` cast would silently wrap for oversized slices.
    let len = u32::try_from(data.len()).map_err(|_| {
        std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            format!("slice of {} values exceeds the u32 length prefix", data.len()),
        )
    })?;
    let mut record = Vec::with_capacity(4 + data.len() * 8);
    record.extend_from_slice(&len.to_le_bytes());
    for value in data {
        record.extend_from_slice(&value.to_le_bytes());
    }
    file.write_all(&record)
}
131
/// Reads a little-endian `u32` element count followed by that many
/// little-endian `f64` values, requiring the count to equal `expect`.
///
/// `expect` is derived from file metadata by the callers in this file, so it is
/// not trusted for allocation: the payload is read in bulk into a buffer that
/// grows as bytes arrive, then decoded. This also replaces one 8-byte read per
/// value with a bulk read.
///
/// # Errors
/// Returns `InvalidData` if the stored count differs from `expect`,
/// `UnexpectedEof` if the file ends early, or the underlying I/O error.
fn read_f64_slice(file: &mut std::fs::File, expect: usize) -> std::io::Result<Vec<f64>> {
    // Upper bound on the speculative reservation made before any payload is read.
    const PREALLOC_CAP: u64 = 1 << 20;

    let mut len_buf = [0u8; 4];
    file.read_exact(&mut len_buf)?;
    let len = u32::from_le_bytes(len_buf) as usize;
    if len != expect {
        return Err(std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            format!("expected {expect} norm values, got {len}"),
        ));
    }

    // `len` came from a u32, so the byte count cannot overflow a u64.
    let byte_len = len as u64 * 8;
    let mut raw = Vec::with_capacity(byte_len.min(PREALLOC_CAP) as usize);
    let got = file.by_ref().take(byte_len).read_to_end(&mut raw)?;
    if got as u64 != byte_len {
        return Err(std::io::Error::new(
            std::io::ErrorKind::UnexpectedEof,
            format!("norm values truncated: expected {byte_len} bytes, found {got}"),
        ));
    }

    Ok(raw
        .chunks_exact(8)
        .map(|chunk| {
            let mut word = [0u8; 8];
            word.copy_from_slice(chunk);
            f64::from_le_bytes(word)
        })
        .collect())
}
150
151fn write_bundle(
152    path: impl AsRef<Path>,
153    w: &WeightStore,
154    meta: Option<ModelMetadata>,
155    norm: Option<&NormStats>,
156    config: Option<&UmapConfig>,
157) -> std::io::Result<()> {
158    let mut names: Vec<String> = w.0.keys().cloned().collect();
159    names.sort();
160    let mut file = std::fs::File::create(path.as_ref())?;
161    file.write_all(MAGIC)?;
162    let version = if config.is_some() {
163        VERSION_V4
164    } else if norm.is_some() {
165        VERSION_V3
166    } else if meta.is_some() {
167        VERSION_V2
168    } else {
169        VERSION_V1
170    };
171    file.write_all(&version.to_le_bytes())?;
172    if let Some(m) = meta {
173        file.write_all(&(m.n_train as u32).to_le_bytes())?;
174        file.write_all(&(m.n_features as u32).to_le_bytes())?;
175        file.write_all(&(m.n_pos as u32).to_le_bytes())?;
176        file.write_all(&(m.n_neg as u32).to_le_bytes())?;
177    }
178    if let Some(n) = norm {
179        write_f64_slice(&mut file, &n.mean)?;
180        write_f64_slice(&mut file, &n.std)?;
181    }
182    if let Some(cfg) = config {
183        let json = serde_json::to_vec(cfg)
184            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()))?;
185        write_bytes(&mut file, &json)?;
186    }
187    let count = names.len() as u32;
188    file.write_all(&count.to_le_bytes())?;
189    for name in &names {
190        let data = &w.0[name];
191        let name_bytes = name.as_bytes();
192        file.write_all(&(name_bytes.len() as u32).to_le_bytes())?;
193        file.write_all(name_bytes)?;
194        file.write_all(&(data.len() as u32).to_le_bytes())?;
195        for &v in data {
196            file.write_all(&v.to_le_bytes())?;
197        }
198    }
199    Ok(())
200}
201
202fn read_header(file: &mut std::fs::File) -> std::io::Result<u32> {
203    let mut magic = [0u8; 4];
204    file.read_exact(&mut magic)?;
205    if &magic != MAGIC {
206        return Err(std::io::Error::new(
207            std::io::ErrorKind::InvalidData,
208            "not an rlx-umap weight file (expected RUMA magic)",
209        ));
210    }
211    let mut word_buf = [0u8; 4];
212    file.read_exact(&mut word_buf)?;
213    Ok(u32::from_le_bytes(word_buf))
214}
215
216fn load_bundle(path: impl AsRef<Path>) -> std::io::Result<LoadedModel> {
217    let mut file = std::fs::File::open(path.as_ref())?;
218    let version = read_header(&mut file)?;
219
220    let (meta, norm, config, count) = match version {
221        VERSION_V4 => {
222            let m = read_meta(&mut file)?;
223            let mean = read_f64_slice(&mut file, m.n_features)?;
224            let std = read_f64_slice(&mut file, m.n_features)?;
225            let json = read_bytes(&mut file)?;
226            let cfg: UmapConfig = serde_json::from_slice(&json)
227                .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()))?;
228            let count = read_count(&mut file)?;
229            (Some(m), Some(NormStats { mean, std }), Some(cfg), count)
230        }
231        VERSION_V3 => {
232            let m = read_meta(&mut file)?;
233            let mean = read_f64_slice(&mut file, m.n_features)?;
234            let std = read_f64_slice(&mut file, m.n_features)?;
235            let count = read_count(&mut file)?;
236            (Some(m), Some(NormStats { mean, std }), None, count)
237        }
238        VERSION_V2 => {
239            let m = read_meta(&mut file)?;
240            let count = read_count(&mut file)?;
241            (Some(m), None, None, count)
242        }
243        VERSION_V1 => {
244            let count = read_count(&mut file)?;
245            (None, None, None, count)
246        }
247        _ => {
248            // Legacy: first u32 after magic was tensor count.
249            let count = version as usize;
250            (None, None, None, count)
251        }
252    };
253
254    let meta = meta.ok_or_else(|| {
255        std::io::Error::new(
256            std::io::ErrorKind::InvalidData,
257            "file has weights only — use load_weights or re-save with save_model",
258        )
259    })?;
260
261    let norm = norm.unwrap_or_else(|| NormStats {
262        mean: vec![0.0; meta.n_features],
263        std: vec![1.0; meta.n_features],
264    });
265
266    let weights = read_tensors(&mut file, count)?;
267
268    Ok(LoadedModel {
269        weights,
270        meta,
271        norm,
272        config,
273    })
274}
275
276fn read_tensors(file: &mut std::fs::File, count: usize) -> std::io::Result<WeightStore> {
277    let mut weights = WeightStore::default();
278    for _ in 0..count {
279        let mut nlen_buf = [0u8; 4];
280        file.read_exact(&mut nlen_buf)?;
281        let nlen = u32::from_le_bytes(nlen_buf) as usize;
282        let mut name_bytes = vec![0u8; nlen];
283        file.read_exact(&mut name_bytes)?;
284        let name = String::from_utf8(name_bytes)
285            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()))?;
286        let mut dlen_buf = [0u8; 4];
287        file.read_exact(&mut dlen_buf)?;
288        let dlen = u32::from_le_bytes(dlen_buf) as usize;
289        let mut data = vec![0f32; dlen];
290        for slot in &mut data {
291            let mut b = [0u8; 4];
292            file.read_exact(&mut b)?;
293            *slot = f32::from_le_bytes(b);
294        }
295        weights.0.insert(name, data);
296    }
297    Ok(weights)
298}
299
300fn read_meta(file: &mut std::fs::File) -> std::io::Result<ModelMetadata> {
301    let mut buf = [0u8; 4];
302    file.read_exact(&mut buf)?;
303    let n_train = u32::from_le_bytes(buf) as usize;
304    file.read_exact(&mut buf)?;
305    let n_features = u32::from_le_bytes(buf) as usize;
306    file.read_exact(&mut buf)?;
307    let n_pos = u32::from_le_bytes(buf) as usize;
308    file.read_exact(&mut buf)?;
309    let n_neg = u32::from_le_bytes(buf) as usize;
310    Ok(ModelMetadata {
311        n_train,
312        n_features,
313        n_pos,
314        n_neg,
315    })
316}
317
/// Reads one little-endian `u32` from `file` and returns it widened to `usize`.
///
/// # Errors
/// Returns the I/O error raised if fewer than four bytes are available.
fn read_count(file: &mut std::fs::File) -> std::io::Result<usize> {
    let mut raw = [0u8; 4];
    file.read_exact(&mut raw)
        .map(|()| u32::from_le_bytes(raw) as usize)
}
323
324/// Suggested path: `dir/model.safetensors`.
325pub fn model_path(dir: impl AsRef<Path>, stem: &str) -> PathBuf {
326    crate::model_io::model_path(dir, stem)
327}