1use super::format::ModelFormat;
7use super::model::{Model, ModelMetadata, ModelState};
8use crate::{Error, Result, Tensor};
9use std::fs::File;
10use std::io::Read;
11use std::path::Path;
12
13pub fn load_model(path: impl AsRef<Path>) -> Result<Model> {
30 let path = path.as_ref();
31
32 let ext = path
34 .extension()
35 .and_then(|s| s.to_str())
36 .ok_or_else(|| Error::Serialization("File has no extension".to_string()))?;
37
38 let format = ModelFormat::from_extension(ext)
39 .ok_or_else(|| Error::Serialization(format!("Unsupported file extension: {ext}")))?;
40
41 if format == ModelFormat::SafeTensors {
43 return load_safetensors(path);
44 }
45 if format == ModelFormat::Apr {
46 return load_apr(path);
47 }
48 #[cfg(feature = "gguf")]
49 if format == ModelFormat::Gguf {
50 return load_gguf(path);
51 }
52
53 let mut file = File::open(path)?;
55
56 let mut content = String::new();
57 file.read_to_string(&mut content)?;
58
59 let state: ModelState = match format {
61 ModelFormat::Json => serde_json::from_str(&content)
62 .map_err(|e| Error::Serialization(format!("JSON deserialization failed: {e}")))?,
63 ModelFormat::Yaml => serde_yaml::from_str(&content)
64 .map_err(|e| Error::Serialization(format!("YAML deserialization failed: {e}")))?,
65 ModelFormat::SafeTensors => unreachable!(), ModelFormat::Apr => unreachable!(), #[cfg(feature = "gguf")]
68 ModelFormat::Gguf => unreachable!(), };
70
71 Ok(Model::from_state(state))
73}
74
/// Loads a model from a GGUF file (only compiled with the `gguf` feature).
///
/// The architecture reported by the reader is used when present, otherwise
/// `"unknown"`. The model name comes from the reader, falling back to the
/// file stem of `path`, and finally to `"gguf-model"`. Every tensor is pulled
/// out as `f32` data; the shape returned alongside the data is discarded and
/// each tensor is built with `Tensor::from_vec(data, false)`.
///
/// # Errors
///
/// Returns `Error::Serialization` when the file cannot be parsed as GGUF or
/// when tensor extraction fails.
#[cfg(feature = "gguf")]
fn load_gguf(path: &Path) -> Result<Model> {
    use aprender::format::gguf::GgufReader;

    let gguf = GgufReader::from_file(path)
        .map_err(|err| Error::Serialization(format!("GGUF parsing failed: {err}")))?;

    let architecture = match gguf.architecture() {
        Some(arch) => arch,
        None => "unknown".to_string(),
    };

    let model_name = match gguf.model_name() {
        Some(name) => name,
        None => {
            // No embedded name: derive one from the file name instead.
            let stem = path.file_stem().and_then(|s| s.to_str());
            stem.unwrap_or("gguf-model").to_string()
        }
    };

    let metadata = ModelMetadata::new(model_name, architecture);

    let extracted = gguf
        .get_all_tensors_f32()
        .map_err(|err| Error::Serialization(format!("GGUF tensor extraction failed: {err}")))?;

    let mut parameters: Vec<(String, Tensor)> = Vec::new();
    for (tensor_name, (values, _shape)) in extracted {
        parameters.push((tensor_name, Tensor::from_vec(values, false)));
    }

    Ok(Model::new(metadata, parameters))
}
105
106fn load_safetensors(path: &Path) -> Result<Model> {
108 let data = std::fs::read(path)
110 .map_err(|e| Error::Serialization(format!("Failed to read file: {e}")))?;
111
112 let (_, st_metadata) = safetensors::SafeTensors::read_metadata(&data)
114 .map_err(|e| Error::Serialization(format!("SafeTensors parsing failed: {e}")))?;
115
116 let custom_meta = st_metadata.metadata();
118 let name = custom_meta
119 .as_ref()
120 .and_then(|m| m.get("name").cloned())
121 .unwrap_or_else(|| "unknown".to_string());
122 let architecture = custom_meta
123 .as_ref()
124 .and_then(|m| m.get("architecture").cloned())
125 .unwrap_or_else(|| "unknown".to_string());
126
127 let metadata = ModelMetadata::new(name, architecture);
128
129 let safetensors = safetensors::SafeTensors::deserialize(&data)
131 .map_err(|e| Error::Serialization(format!("SafeTensors parsing failed: {e}")))?;
132
133 let parameters: Vec<(String, Tensor)> = safetensors
135 .names()
136 .into_iter()
137 .map(|name| {
138 let tensor_view = safetensors
139 .tensor(name)
140 .expect("tensor name from names() must exist in SafeTensors");
141 let data: &[f32] = bytemuck::cast_slice(tensor_view.data());
142 let tensor = Tensor::from_vec(data.to_vec(), false); (name.to_string(), tensor)
144 })
145 .collect();
146
147 Ok(Model::new(metadata, parameters))
148}
149
150fn load_apr(path: &Path) -> Result<Model> {
155 use aprender::serialization::apr::AprReader;
156
157 let reader = AprReader::open(path)
158 .map_err(|e| Error::Serialization(format!("APR parsing failed: {e}")))?;
159
160 let name =
162 reader.get_metadata("model_name").and_then(|v| v.as_str()).unwrap_or("unknown").to_string();
163 let architecture = reader
164 .get_metadata("architecture")
165 .and_then(|v| v.as_str())
166 .unwrap_or("unknown")
167 .to_string();
168
169 let metadata = ModelMetadata::new(name, architecture);
170
171 let parameters: Vec<(String, Tensor)> = reader
173 .tensors
174 .iter()
175 .filter(|td| !td.name.starts_with("__training__"))
176 .map(|td| {
177 let data = reader
178 .read_tensor_as_f32(&td.name)
179 .map_err(|e| Error::Serialization(format!("APR tensor read failed: {e}")))
180 .expect("tensor listed in descriptors must be readable");
181 (td.name.clone(), Tensor::from_vec(data, false))
182 })
183 .collect();
184
185 Ok(Model::new(metadata, parameters))
186}
187
#[cfg(test)]
mod tests {
    //! Round-trip and failure-path tests for `load_model`.
    //!
    //! Every test that touches the filesystem works inside its own temporary
    //! directory. The directory is removed when its `TempDir` guard is dropped,
    //! so files are cleaned up even when an assertion fails part-way through.

    use super::*;
    use crate::io::{save_model, Model, ModelMetadata, SaveConfig};
    use crate::Tensor;
    use std::path::PathBuf;
    use tempfile::TempDir;

    /// Creates a fresh temporary directory and returns it together with a
    /// not-yet-existing file path inside it that carries `extension`.
    ///
    /// The caller must keep the returned `TempDir` alive for as long as the
    /// path is in use.
    fn scratch_path(extension: &str) -> (TempDir, PathBuf) {
        let dir = tempfile::tempdir().expect("temp dir creation should succeed");
        let path = dir.path().join("model").with_extension(extension);
        (dir, path)
    }

    /// Saves `model` with `config` to a scratch file with `extension`, then
    /// loads it back through `load_model`.
    fn save_then_load(model: &Model, config: &SaveConfig, extension: &str) -> Model {
        let (_dir, path) = scratch_path(extension);
        save_model(model, &path, config).expect("save should succeed");
        load_model(&path).expect("load should succeed")
    }

    /// Writes `bytes` to a scratch file with `extension` and asserts that
    /// `load_model` rejects it.
    fn assert_rejects_content(extension: &str, bytes: &[u8]) {
        let (_dir, path) = scratch_path(extension);
        std::fs::write(&path, bytes).expect("file write should succeed");
        assert!(load_model(&path).is_err());
    }

    #[test]
    fn test_load_model_json() {
        let params = vec![
            ("weight".to_string(), Tensor::from_vec(vec![1.0, 2.0, 3.0], true)),
            ("bias".to_string(), Tensor::from_vec(vec![0.1], false)),
        ];
        let original = Model::new(ModelMetadata::new("test-model", "linear"), params);

        let loaded = save_then_load(&original, &SaveConfig::new(ModelFormat::Json), "json");

        assert_eq!(original.metadata.name, loaded.metadata.name);
        assert_eq!(original.metadata.architecture, loaded.metadata.architecture);
        assert_eq!(original.parameters.len(), loaded.parameters.len());
    }

    #[test]
    fn test_load_model_yaml() {
        let params = vec![("weight".to_string(), Tensor::from_vec(vec![1.0, 2.0], true))];
        let original = Model::new(ModelMetadata::new("yaml-test", "simple"), params);

        let loaded = save_then_load(&original, &SaveConfig::new(ModelFormat::Yaml), "yaml");

        assert_eq!(original.metadata.name, loaded.metadata.name);
        assert_eq!(original.parameters.len(), loaded.parameters.len());
    }

    #[test]
    fn test_load_unsupported_extension() {
        let (_dir, path) = scratch_path("unknown");

        let result = load_model(&path);
        assert!(result.is_err());
    }

    #[test]
    fn test_save_load_round_trip() {
        let params = vec![
            ("layer1.weight".to_string(), Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], true)),
            ("layer1.bias".to_string(), Tensor::from_vec(vec![0.1, 0.2], true)),
            ("layer2.weight".to_string(), Tensor::from_vec(vec![5.0, 6.0], false)),
        ];

        let meta = ModelMetadata::new("round-trip-test", "multi-layer")
            .with_custom("layers", serde_json::json!(2))
            .with_custom("hidden_size", serde_json::json!(4));

        let original = Model::new(meta, params);

        let config = SaveConfig::new(ModelFormat::Json).with_pretty(true);
        let loaded = save_then_load(&original, &config, "json");

        assert_eq!(original.parameters.len(), loaded.parameters.len());

        for (orig_name, orig_tensor) in &original.parameters {
            let loaded_tensor =
                loaded.get_parameter(orig_name).expect("parameter should exist after load");
            assert_eq!(orig_tensor.data(), loaded_tensor.data());
            assert_eq!(orig_tensor.requires_grad(), loaded_tensor.requires_grad());
        }

        assert_eq!(original.metadata.custom.len(), loaded.metadata.custom.len());
    }

    #[test]
    fn test_load_model_file_not_found() {
        let result = load_model("nonexistent_file.json");
        assert!(result.is_err());
    }

    #[test]
    fn test_load_model_no_extension() {
        let result = load_model("model_without_extension");
        assert!(result.is_err());
        if let Err(err) = result {
            assert!(err.to_string().contains("no extension"));
        }
    }

    #[test]
    fn test_load_model_invalid_json() {
        assert_rejects_content("json", b"{ invalid json }");
    }

    #[test]
    fn test_load_model_invalid_yaml() {
        assert_rejects_content("yaml", b"this: is: not: valid: yaml: [}");
    }

    #[test]
    fn test_load_yml_extension() {
        let params = vec![("weight".to_string(), Tensor::from_vec(vec![1.0], true))];
        let original = Model::new(ModelMetadata::new("yml-test", "simple"), params);

        let loaded = save_then_load(&original, &SaveConfig::new(ModelFormat::Yaml), "yml");

        assert_eq!(original.metadata.name, loaded.metadata.name);
    }

    #[test]
    fn test_load_model_safetensors() {
        let params = vec![
            ("weight".to_string(), Tensor::from_vec(vec![1.0, 2.0, 3.0], true)),
            ("bias".to_string(), Tensor::from_vec(vec![0.1], false)),
        ];
        let original = Model::new(ModelMetadata::new("safetensor-test", "linear"), params);

        let loaded =
            save_then_load(&original, &SaveConfig::new(ModelFormat::SafeTensors), "safetensors");

        assert_eq!(original.metadata.name, loaded.metadata.name);
        assert_eq!(original.metadata.architecture, loaded.metadata.architecture);
        assert_eq!(original.parameters.len(), loaded.parameters.len());
    }

    #[test]
    fn test_safetensors_round_trip_data_integrity() {
        let params = vec![
            ("layer1.weight".to_string(), Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], true)),
            ("layer1.bias".to_string(), Tensor::from_vec(vec![0.5, 0.6], false)),
        ];
        let original = Model::new(ModelMetadata::new("round-trip", "mlp"), params);

        let loaded =
            save_then_load(&original, &SaveConfig::new(ModelFormat::SafeTensors), "safetensors");

        for (name, orig_tensor) in &original.parameters {
            let loaded_tensor =
                loaded.get_parameter(name).expect("parameter should exist after load");
            assert_eq!(orig_tensor.data(), loaded_tensor.data());
        }
    }

    #[test]
    fn test_load_safetensors_file_not_found() {
        let result = load_model("nonexistent.safetensors");
        assert!(result.is_err());
    }

    #[test]
    fn test_load_safetensors_invalid_data() {
        assert_rejects_content("safetensors", b"not valid safetensors binary data");
    }

    #[test]
    fn test_load_safetensors_large_model() {
        let large_data: Vec<f32> = (0..5000).map(|i| i as f32 * 0.001).collect();
        let params = vec![
            ("large_weight".to_string(), Tensor::from_vec(large_data, false)),
            ("small_bias".to_string(), Tensor::from_vec(vec![0.1, 0.2], false)),
        ];
        let original = Model::new(ModelMetadata::new("large-model", "test"), params);

        let loaded =
            save_then_load(&original, &SaveConfig::new(ModelFormat::SafeTensors), "safetensors");

        let loaded_large =
            loaded.get_parameter("large_weight").expect("parameter should exist after load");
        assert_eq!(loaded_large.len(), 5000);

        // Spot-check both ends of the buffer.
        let data = loaded_large.data();
        assert!((data[[0]] - 0.0).abs() < 1e-6);
        assert!((data[[4999]] - 4.999).abs() < 1e-3);
    }

    #[test]
    fn test_load_safetensors_metadata_preserved() {
        let params = vec![("w".to_string(), Tensor::from_vec(vec![1.0], false))];
        let original = Model::new(ModelMetadata::new("meta-model", "transformer"), params);

        let loaded =
            save_then_load(&original, &SaveConfig::new(ModelFormat::SafeTensors), "safetensors");

        assert_eq!(loaded.metadata.name, "meta-model");
        assert_eq!(loaded.metadata.architecture, "transformer");
    }

    #[test]
    fn load_bench_loading_time() {
        let params = vec![("w".to_string(), Tensor::from_vec(vec![1.0; 1000], false))];
        let original = Model::new(ModelMetadata::new("bench-model", "test"), params);

        let (_dir, path) = scratch_path("safetensors");
        let config = SaveConfig::new(ModelFormat::SafeTensors);
        save_model(&original, &path, &config).expect("save should succeed");

        // Only the load is timed; the save above is setup.
        let start = std::time::Instant::now();
        let _loaded = load_model(&path).expect("load should succeed");
        let loading_time = start.elapsed();

        assert!(loading_time.as_millis() < 5000, "load_bench: {loading_time:?}");
    }

    #[test]
    fn test_apr_round_trip() {
        let params = vec![
            ("layer1.weight".to_string(), Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], true)),
            ("layer1.bias".to_string(), Tensor::from_vec(vec![0.5, 0.6], false)),
        ];
        let original = Model::new(ModelMetadata::new("apr-test", "transformer"), params);

        let (_dir, path) = scratch_path("apr");
        let config = SaveConfig::new(ModelFormat::Apr);
        save_model(&original, &path, &config).expect("APR save should succeed");

        let loaded = load_model(&path).expect("APR load should succeed");

        assert_eq!(loaded.metadata.name, "apr-test");
        assert_eq!(loaded.metadata.architecture, "transformer");
        assert_eq!(loaded.parameters.len(), 2);

        for (name, orig_tensor) in &original.parameters {
            let loaded_tensor = loaded.get_parameter(name).expect("tensor should exist");
            assert_eq!(orig_tensor.data(), loaded_tensor.data());
        }
    }
}