1use std::io::{Read, Write};
24use std::path::{Path, PathBuf};
25
26use crate::config::UmapConfig;
27use crate::encoder::mlp::ModelSpec;
28use crate::utils::NormStats;
29use crate::weights::WeightStore;
30
/// Four-byte signature at the start of every rlx-umap RUMA file.
const MAGIC: &[u8; 4] = &[b'R', b'U', b'M', b'A'];

/// Format v1: tensor table only.
const VERSION_V1: u32 = 1;
/// Format v2: v1 plus a `ModelMetadata` record ahead of the tensor table.
const VERSION_V2: u32 = 2;
/// Format v3: v2 plus normalisation statistics (mean vector, then std vector).
const VERSION_V3: u32 = 3;
/// Format v4: v3 plus a length-prefixed JSON-encoded `UmapConfig`.
const VERSION_V4: u32 = 4;
36
/// Training bookkeeping stored ahead of the tensors in format v2+ RUMA files.
///
/// Every field is persisted as a little-endian `u32` word.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct ModelMetadata {
    /// Presumably the number of training samples; stored verbatim — confirm against the trainer.
    pub n_train: usize,
    /// Input dimensionality; also the required length of the stored norm mean/std vectors.
    pub n_features: usize,
    /// Stored and restored verbatim; presumably a positive-sample count — confirm.
    pub n_pos: usize,
    /// Stored and restored verbatim; presumably a negative-sample count — confirm.
    pub n_neg: usize,
}
45
46#[derive(Debug, Clone)]
48pub struct LoadedModel {
49 pub weights: WeightStore,
50 pub meta: ModelMetadata,
51 pub norm: NormStats,
52 pub config: Option<UmapConfig>,
53}
54
55pub struct SaveBundle<'a> {
57 pub weights: &'a WeightStore,
58 pub meta: ModelMetadata,
59 pub norm: &'a NormStats,
60 pub config: &'a UmapConfig,
61}
62
63pub fn save_weights(
65 w: &WeightStore,
66 spec: &ModelSpec,
67 path: impl AsRef<Path>,
68) -> std::io::Result<()> {
69 crate::model_io::save_weights(w, spec, path)
70}
71
72pub fn load_weights(path: impl AsRef<Path>) -> std::io::Result<WeightStore> {
74 crate::model_io::load_weights(path)
75}
76
77pub(crate) fn save_weights_ruama(w: &WeightStore, path: impl AsRef<Path>) -> std::io::Result<()> {
79 write_bundle(path, w, None, None, None)
80}
81
82pub(crate) fn load_weights_ruama(path: impl AsRef<Path>) -> std::io::Result<WeightStore> {
84 let mut file = std::fs::File::open(path.as_ref())?;
85 let version = read_header(&mut file)?;
86 if version == VERSION_V1 {
87 let count = read_count(&mut file)?;
88 return read_tensors(&mut file, count);
89 }
90 drop(file);
91 load_bundle(path).map(|b| b.weights)
92}
93
94pub fn save_model(bundle: SaveBundle<'_>, path: impl AsRef<Path>) -> std::io::Result<()> {
96 crate::model_io::save_model(bundle, path)
97}
98
99pub fn load_model(path: impl AsRef<Path>) -> std::io::Result<LoadedModel> {
101 crate::model_io::load_model(path)
102}
103
104pub(crate) fn load_legacy_ruama(path: impl AsRef<Path>) -> std::io::Result<LoadedModel> {
106 load_bundle(path)
107}
108
/// Writes `data` as a length-prefixed blob: a little-endian `u32` byte count, then the bytes.
///
/// # Errors
/// `InvalidInput` if `data` is longer than `u32::MAX` bytes. The prefix could not represent
/// that length, and silently truncating it would produce a corrupt file. Any write error is
/// propagated.
fn write_bytes(file: &mut std::fs::File, data: &[u8]) -> std::io::Result<()> {
    if data.len() > u32::MAX as usize {
        return Err(std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            format!("blob of {} bytes exceeds the u32 length prefix", data.len()),
        ));
    }
    // Range-checked above, so this cast cannot truncate.
    file.write_all(&(data.len() as u32).to_le_bytes())?;
    file.write_all(data)
}
114
/// Reads a length-prefixed blob: a little-endian `u32` byte count followed by that many bytes.
///
/// # Errors
/// `UnexpectedEof` if the file ends before the prefix or the payload is complete; other read
/// errors are propagated.
fn read_bytes(file: &mut std::fs::File) -> std::io::Result<Vec<u8>> {
    let mut len_buf = [0u8; 4];
    file.read_exact(&mut len_buf)?;
    let len = u32::from_le_bytes(len_buf) as usize;
    // Do not trust the prefix for an up-front allocation. A corrupt prefix could ask for
    // ~4 GiB before the short read is noticed. `take` + `read_to_end` grows the buffer only
    // as bytes actually arrive. (`&mut *file` reborrows so `take` doesn't consume `file`.)
    let mut data = Vec::new();
    (&mut *file).take(len as u64).read_to_end(&mut data)?;
    if data.len() != len {
        return Err(std::io::Error::new(
            std::io::ErrorKind::UnexpectedEof,
            format!("length prefix says {len} bytes but only {} remain", data.len()),
        ));
    }
    Ok(data)
}
123
/// Writes `data` as a little-endian `u32` element count followed by each value as a
/// little-endian `f64`.
///
/// # Errors
/// `InvalidInput` if `data` has more than `u32::MAX` elements, since the count prefix cannot
/// represent that. Any write error is propagated.
fn write_f64_slice(file: &mut std::fs::File, data: &[f64]) -> std::io::Result<()> {
    if data.len() > u32::MAX as usize {
        return Err(std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            format!("{} values exceed the u32 length prefix", data.len()),
        ));
    }
    // Encode into one buffer and write once: `File` is unbuffered, so a `write_all` per value
    // would cost a syscall each.
    let mut buf = Vec::with_capacity(4 + 8 * data.len());
    // Range-checked above, so this cast cannot truncate.
    buf.extend_from_slice(&(data.len() as u32).to_le_bytes());
    for v in data {
        buf.extend_from_slice(&v.to_le_bytes());
    }
    file.write_all(&buf)
}
131
/// Reads a `u32`-count-prefixed vector of little-endian `f64` values whose count must equal
/// `expect`.
///
/// # Errors
/// * `InvalidData` if the stored count differs from `expect`.
/// * `UnexpectedEof` if the file ends before all values are read.
/// * Other read errors are propagated.
fn read_f64_slice(file: &mut std::fs::File, expect: usize) -> std::io::Result<Vec<f64>> {
    let mut len_buf = [0u8; 4];
    file.read_exact(&mut len_buf)?;
    let len = u32::from_le_bytes(len_buf) as usize;
    if len != expect {
        return Err(std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            format!("expected {expect} norm values, got {len}"),
        ));
    }
    let byte_len = len.checked_mul(8).ok_or_else(|| {
        std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            format!("{len} f64 values do not fit in memory"),
        )
    })?;
    // One bulk read instead of a syscall per value (the `File` is unbuffered). `take` +
    // `read_to_end` grows the buffer only as bytes arrive, so a short file yields an error
    // rather than a large speculative allocation. `&mut *file` reborrows so `file` survives.
    let mut raw = Vec::new();
    (&mut *file).take(byte_len as u64).read_to_end(&mut raw)?;
    if raw.len() != byte_len {
        return Err(std::io::Error::new(
            std::io::ErrorKind::UnexpectedEof,
            format!("expected {byte_len} bytes of f64 data, found {}", raw.len()),
        ));
    }
    Ok(raw
        .chunks_exact(8)
        .map(|chunk| {
            let mut bytes = [0u8; 8];
            bytes.copy_from_slice(chunk);
            f64::from_le_bytes(bytes)
        })
        .collect())
}
150
151fn write_bundle(
152 path: impl AsRef<Path>,
153 w: &WeightStore,
154 meta: Option<ModelMetadata>,
155 norm: Option<&NormStats>,
156 config: Option<&UmapConfig>,
157) -> std::io::Result<()> {
158 let mut names: Vec<String> = w.0.keys().cloned().collect();
159 names.sort();
160 let mut file = std::fs::File::create(path.as_ref())?;
161 file.write_all(MAGIC)?;
162 let version = if config.is_some() {
163 VERSION_V4
164 } else if norm.is_some() {
165 VERSION_V3
166 } else if meta.is_some() {
167 VERSION_V2
168 } else {
169 VERSION_V1
170 };
171 file.write_all(&version.to_le_bytes())?;
172 if let Some(m) = meta {
173 file.write_all(&(m.n_train as u32).to_le_bytes())?;
174 file.write_all(&(m.n_features as u32).to_le_bytes())?;
175 file.write_all(&(m.n_pos as u32).to_le_bytes())?;
176 file.write_all(&(m.n_neg as u32).to_le_bytes())?;
177 }
178 if let Some(n) = norm {
179 write_f64_slice(&mut file, &n.mean)?;
180 write_f64_slice(&mut file, &n.std)?;
181 }
182 if let Some(cfg) = config {
183 let json = serde_json::to_vec(cfg)
184 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()))?;
185 write_bytes(&mut file, &json)?;
186 }
187 let count = names.len() as u32;
188 file.write_all(&count.to_le_bytes())?;
189 for name in &names {
190 let data = &w.0[name];
191 let name_bytes = name.as_bytes();
192 file.write_all(&(name_bytes.len() as u32).to_le_bytes())?;
193 file.write_all(name_bytes)?;
194 file.write_all(&(data.len() as u32).to_le_bytes())?;
195 for &v in data {
196 file.write_all(&v.to_le_bytes())?;
197 }
198 }
199 Ok(())
200}
201
202fn read_header(file: &mut std::fs::File) -> std::io::Result<u32> {
203 let mut magic = [0u8; 4];
204 file.read_exact(&mut magic)?;
205 if &magic != MAGIC {
206 return Err(std::io::Error::new(
207 std::io::ErrorKind::InvalidData,
208 "not an rlx-umap weight file (expected RUMA magic)",
209 ));
210 }
211 let mut word_buf = [0u8; 4];
212 file.read_exact(&mut word_buf)?;
213 Ok(u32::from_le_bytes(word_buf))
214}
215
216fn load_bundle(path: impl AsRef<Path>) -> std::io::Result<LoadedModel> {
217 let mut file = std::fs::File::open(path.as_ref())?;
218 let version = read_header(&mut file)?;
219
220 let (meta, norm, config, count) = match version {
221 VERSION_V4 => {
222 let m = read_meta(&mut file)?;
223 let mean = read_f64_slice(&mut file, m.n_features)?;
224 let std = read_f64_slice(&mut file, m.n_features)?;
225 let json = read_bytes(&mut file)?;
226 let cfg: UmapConfig = serde_json::from_slice(&json)
227 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()))?;
228 let count = read_count(&mut file)?;
229 (Some(m), Some(NormStats { mean, std }), Some(cfg), count)
230 }
231 VERSION_V3 => {
232 let m = read_meta(&mut file)?;
233 let mean = read_f64_slice(&mut file, m.n_features)?;
234 let std = read_f64_slice(&mut file, m.n_features)?;
235 let count = read_count(&mut file)?;
236 (Some(m), Some(NormStats { mean, std }), None, count)
237 }
238 VERSION_V2 => {
239 let m = read_meta(&mut file)?;
240 let count = read_count(&mut file)?;
241 (Some(m), None, None, count)
242 }
243 VERSION_V1 => {
244 let count = read_count(&mut file)?;
245 (None, None, None, count)
246 }
247 _ => {
248 let count = version as usize;
250 (None, None, None, count)
251 }
252 };
253
254 let meta = meta.ok_or_else(|| {
255 std::io::Error::new(
256 std::io::ErrorKind::InvalidData,
257 "file has weights only — use load_weights or re-save with save_model",
258 )
259 })?;
260
261 let norm = norm.unwrap_or_else(|| NormStats {
262 mean: vec![0.0; meta.n_features],
263 std: vec![1.0; meta.n_features],
264 });
265
266 let weights = read_tensors(&mut file, count)?;
267
268 Ok(LoadedModel {
269 weights,
270 meta,
271 norm,
272 config,
273 })
274}
275
276fn read_tensors(file: &mut std::fs::File, count: usize) -> std::io::Result<WeightStore> {
277 let mut weights = WeightStore::default();
278 for _ in 0..count {
279 let mut nlen_buf = [0u8; 4];
280 file.read_exact(&mut nlen_buf)?;
281 let nlen = u32::from_le_bytes(nlen_buf) as usize;
282 let mut name_bytes = vec![0u8; nlen];
283 file.read_exact(&mut name_bytes)?;
284 let name = String::from_utf8(name_bytes)
285 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()))?;
286 let mut dlen_buf = [0u8; 4];
287 file.read_exact(&mut dlen_buf)?;
288 let dlen = u32::from_le_bytes(dlen_buf) as usize;
289 let mut data = vec![0f32; dlen];
290 for slot in &mut data {
291 let mut b = [0u8; 4];
292 file.read_exact(&mut b)?;
293 *slot = f32::from_le_bytes(b);
294 }
295 weights.0.insert(name, data);
296 }
297 Ok(weights)
298}
299
300fn read_meta(file: &mut std::fs::File) -> std::io::Result<ModelMetadata> {
301 let mut buf = [0u8; 4];
302 file.read_exact(&mut buf)?;
303 let n_train = u32::from_le_bytes(buf) as usize;
304 file.read_exact(&mut buf)?;
305 let n_features = u32::from_le_bytes(buf) as usize;
306 file.read_exact(&mut buf)?;
307 let n_pos = u32::from_le_bytes(buf) as usize;
308 file.read_exact(&mut buf)?;
309 let n_neg = u32::from_le_bytes(buf) as usize;
310 Ok(ModelMetadata {
311 n_train,
312 n_features,
313 n_pos,
314 n_neg,
315 })
316}
317
/// Reads one little-endian `u32` word and widens it to `usize`.
///
/// # Errors
/// `UnexpectedEof` if fewer than four bytes remain; other read errors are propagated.
fn read_count(file: &mut std::fs::File) -> std::io::Result<usize> {
    let mut word = [0u8; 4];
    file.read_exact(&mut word)
        .map(|()| u32::from_le_bytes(word) as usize)
}
323
324pub fn model_path(dir: impl AsRef<Path>, stem: &str) -> PathBuf {
326 crate::model_io::model_path(dir, stem)
327}