Skip to main content

entrenar/io/
load.rs

1//! Model loading functionality
2//!
//! Binary formats (SafeTensors, APR, and GGUF when the `gguf` feature is
//! enabled) are decoded into owned `f32` parameter tensors; text formats
//! (JSON, YAML) are deserialized with serde into a `ModelState`.
5
6use super::format::ModelFormat;
7use super::model::{Model, ModelMetadata, ModelState};
8use crate::{Error, Result, Tensor};
9use std::fs::File;
10use std::io::Read;
11use std::path::Path;
12
13/// Load a model from a file
14///
15/// # Arguments
16///
17/// * `path` - Input file path
18///
19/// The format is automatically detected from the file extension.
20///
21/// # Example
22///
23/// ```no_run
24/// use entrenar::io::load_model;
25///
26/// let model = load_model("model.json").expect("failed to load model");
27/// println!("Loaded model: {}", model.metadata.name);
28/// ```
29pub fn load_model(path: impl AsRef<Path>) -> Result<Model> {
30    let path = path.as_ref();
31
32    // Detect format from extension
33    let ext = path
34        .extension()
35        .and_then(|s| s.to_str())
36        .ok_or_else(|| Error::Serialization("File has no extension".to_string()))?;
37
38    let format = ModelFormat::from_extension(ext)
39        .ok_or_else(|| Error::Serialization(format!("Unsupported file extension: {ext}")))?;
40
41    // Handle binary formats separately
42    if format == ModelFormat::SafeTensors {
43        return load_safetensors(path);
44    }
45    if format == ModelFormat::Apr {
46        return load_apr(path);
47    }
48    #[cfg(feature = "gguf")]
49    if format == ModelFormat::Gguf {
50        return load_gguf(path);
51    }
52
53    // Read file content (text formats)
54    let mut file = File::open(path)?;
55
56    let mut content = String::new();
57    file.read_to_string(&mut content)?;
58
59    // Deserialize based on format
60    let state: ModelState = match format {
61        ModelFormat::Json => serde_json::from_str(&content)
62            .map_err(|e| Error::Serialization(format!("JSON deserialization failed: {e}")))?,
63        ModelFormat::Yaml => serde_yaml::from_str(&content)
64            .map_err(|e| Error::Serialization(format!("YAML deserialization failed: {e}")))?,
65        ModelFormat::SafeTensors => unreachable!(), // Handled above
66        ModelFormat::Apr => unreachable!(),         // Handled above
67        #[cfg(feature = "gguf")]
68        ModelFormat::Gguf => unreachable!(), // Handled above
69    };
70
71    // Convert state to model
72    Ok(Model::from_state(state))
73}
74
/// Load a model from a GGUF file via aprender's `GgufReader`.
///
/// UCBD §5: All model loading goes through the canonical stack.
/// GGUF → aprender::format::gguf::GgufReader → dequantized f32 tensors → Model.
///
/// The architecture defaults to `"unknown"`; the model name falls back to the
/// file stem, then to `"gguf-model"`. Each tensor is dequantized to `f32` and
/// stored with `requires_grad = false`; the shape reported by the reader is
/// not passed on to `Tensor::from_vec`.
#[cfg(feature = "gguf")]
fn load_gguf(path: &Path) -> Result<Model> {
    use aprender::format::gguf::GgufReader;

    let gguf = GgufReader::from_file(path)
        .map_err(|err| Error::Serialization(format!("GGUF parsing failed: {err}")))?;

    let architecture = gguf.architecture().unwrap_or_else(|| String::from("unknown"));
    let model_name = gguf.model_name().unwrap_or_else(|| {
        let stem = path.file_stem().and_then(|s| s.to_str());
        stem.unwrap_or("gguf-model").to_owned()
    });
    let metadata = ModelMetadata::new(model_name, architecture);

    // Dequantizes every tensor (Q4_K, Q6_K, Q8_0, F16, ...) to f32.
    let dequantized = gguf
        .get_all_tensors_f32()
        .map_err(|err| Error::Serialization(format!("GGUF tensor extraction failed: {err}")))?;

    let mut parameters: Vec<(String, Tensor)> = Vec::new();
    for (tensor_name, (values, _shape)) in dequantized {
        parameters.push((tensor_name, Tensor::from_vec(values, false)));
    }

    Ok(Model::new(metadata, parameters))
}
105
106/// Load model from SafeTensors format (HuggingFace compatible)
107fn load_safetensors(path: &Path) -> Result<Model> {
108    // Read binary file
109    let data = std::fs::read(path)
110        .map_err(|e| Error::Serialization(format!("Failed to read file: {e}")))?;
111
112    // Parse SafeTensors and get metadata
113    let (_, st_metadata) = safetensors::SafeTensors::read_metadata(&data)
114        .map_err(|e| Error::Serialization(format!("SafeTensors parsing failed: {e}")))?;
115
116    // Extract custom metadata
117    let custom_meta = st_metadata.metadata();
118    let name = custom_meta
119        .as_ref()
120        .and_then(|m| m.get("name").cloned())
121        .unwrap_or_else(|| "unknown".to_string());
122    let architecture = custom_meta
123        .as_ref()
124        .and_then(|m| m.get("architecture").cloned())
125        .unwrap_or_else(|| "unknown".to_string());
126
127    let metadata = ModelMetadata::new(name, architecture);
128
129    // Deserialize to access tensors
130    let safetensors = safetensors::SafeTensors::deserialize(&data)
131        .map_err(|e| Error::Serialization(format!("SafeTensors parsing failed: {e}")))?;
132
133    // Convert tensors - names() returns Vec<&str>, not an iterator
134    let parameters: Vec<(String, Tensor)> = safetensors
135        .names()
136        .into_iter()
137        .map(|name| {
138            let tensor_view = safetensors
139                .tensor(name)
140                .expect("tensor name from names() must exist in SafeTensors");
141            let data: &[f32] = bytemuck::cast_slice(tensor_view.data());
142            let tensor = Tensor::from_vec(data.to_vec(), false); // Default to no grad
143            (name.to_string(), tensor)
144        })
145        .collect();
146
147    Ok(Model::new(metadata, parameters))
148}
149
150/// ALB-096: Load model from APR format (sovereign stack universal format).
151///
152/// Uses `AprReader` for atomic single-file checkpoints. Skips `__training__.*`
153/// tensors (optimizer state) — those are loaded separately by `CudaTransformerTrainer`.
154fn load_apr(path: &Path) -> Result<Model> {
155    use aprender::serialization::apr::AprReader;
156
157    let reader = AprReader::open(path)
158        .map_err(|e| Error::Serialization(format!("APR parsing failed: {e}")))?;
159
160    // Extract metadata (AprReader maps v2_meta.name → "model_name")
161    let name =
162        reader.get_metadata("model_name").and_then(|v| v.as_str()).unwrap_or("unknown").to_string();
163    let architecture = reader
164        .get_metadata("architecture")
165        .and_then(|v| v.as_str())
166        .unwrap_or("unknown")
167        .to_string();
168
169    let metadata = ModelMetadata::new(name, architecture);
170
171    // Load model weight tensors, skip __training__.* (optimizer state)
172    let parameters: Vec<(String, Tensor)> = reader
173        .tensors
174        .iter()
175        .filter(|td| !td.name.starts_with("__training__"))
176        .map(|td| {
177            let data = reader
178                .read_tensor_as_f32(&td.name)
179                .map_err(|e| Error::Serialization(format!("APR tensor read failed: {e}")))
180                .expect("tensor listed in descriptors must be readable");
181            (td.name.clone(), Tensor::from_vec(data, false))
182        })
183        .collect();
184
185    Ok(Model::new(metadata, parameters))
186}
187
#[cfg(test)]
mod tests {
    use super::*;
    use crate::io::{save_model, Model, ModelMetadata, SaveConfig};
    use crate::Tensor;
    use std::path::PathBuf;
    use tempfile::TempDir;

    /// Returns a fresh temporary directory plus the path `model.<ext>` inside it.
    ///
    /// The file itself is not created. The directory and everything written
    /// into it are deleted when the returned [`TempDir`] is dropped — even if
    /// the test panics — so bind it to a named variable (`_dir`, not `_`) for
    /// as long as the path is in use.
    fn temp_model_path(ext: &str) -> (TempDir, PathBuf) {
        let dir = TempDir::new().expect("temp dir creation should succeed");
        let path = dir.path().join(format!("model.{ext}"));
        (dir, path)
    }

    #[test]
    fn test_load_model_json() {
        // Create and save a model
        let params = vec![
            ("weight".to_string(), Tensor::from_vec(vec![1.0, 2.0, 3.0], true)),
            ("bias".to_string(), Tensor::from_vec(vec![0.1], false)),
        ];
        let original = Model::new(ModelMetadata::new("test-model", "linear"), params);
        let (_dir, path) = temp_model_path("json");

        let config = SaveConfig::new(ModelFormat::Json);
        save_model(&original, &path, &config).expect("save should succeed");

        // Load it back
        let loaded = load_model(&path).expect("load should succeed");

        assert_eq!(original.metadata.name, loaded.metadata.name);
        assert_eq!(original.metadata.architecture, loaded.metadata.architecture);
        assert_eq!(original.parameters.len(), loaded.parameters.len());
    }

    #[test]
    fn test_load_model_yaml() {
        let params = vec![("weight".to_string(), Tensor::from_vec(vec![1.0, 2.0], true))];
        let original = Model::new(ModelMetadata::new("yaml-test", "simple"), params);
        let (_dir, path) = temp_model_path("yaml");

        let config = SaveConfig::new(ModelFormat::Yaml);
        save_model(&original, &path, &config).expect("save should succeed");

        let loaded = load_model(&path).expect("load should succeed");

        assert_eq!(original.metadata.name, loaded.metadata.name);
        assert_eq!(original.parameters.len(), loaded.parameters.len());
    }

    #[test]
    fn test_load_unsupported_extension() {
        let (_dir, path) = temp_model_path("unknown");
        assert!(load_model(&path).is_err());
    }

    #[test]
    fn test_save_load_round_trip() {
        // Create a model with multiple parameters
        let params = vec![
            ("layer1.weight".to_string(), Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], true)),
            ("layer1.bias".to_string(), Tensor::from_vec(vec![0.1, 0.2], true)),
            ("layer2.weight".to_string(), Tensor::from_vec(vec![5.0, 6.0], false)),
        ];

        let meta = ModelMetadata::new("round-trip-test", "multi-layer")
            .with_custom("layers", serde_json::json!(2))
            .with_custom("hidden_size", serde_json::json!(4));

        let original = Model::new(meta, params);
        let (_dir, path) = temp_model_path("json");

        // Save and load
        let config = SaveConfig::new(ModelFormat::Json).with_pretty(true);
        save_model(&original, &path, &config).expect("save should succeed");
        let loaded = load_model(&path).expect("load should succeed");

        // Verify all parameters match
        assert_eq!(original.parameters.len(), loaded.parameters.len());

        for (orig_name, orig_tensor) in &original.parameters {
            let loaded_tensor =
                loaded.get_parameter(orig_name).expect("parameter should exist after load");
            assert_eq!(orig_tensor.data(), loaded_tensor.data());
            assert_eq!(orig_tensor.requires_grad(), loaded_tensor.requires_grad());
        }

        // Verify metadata
        assert_eq!(original.metadata.custom.len(), loaded.metadata.custom.len());
    }

    #[test]
    fn test_load_model_file_not_found() {
        // A path inside a fresh temp dir is guaranteed not to exist, unlike a
        // bare relative name whose presence depends on the working directory.
        let (_dir, path) = temp_model_path("json");
        assert!(load_model(&path).is_err());
    }

    #[test]
    fn test_load_model_no_extension() {
        // Use match instead of unwrap_err since Model doesn't implement Debug
        match load_model("model_without_extension") {
            Ok(_) => panic!("loading a path without an extension should fail"),
            Err(err) => assert!(err.to_string().contains("no extension")),
        }
    }

    #[test]
    fn test_load_model_invalid_json() {
        let (_dir, path) = temp_model_path("json");
        std::fs::write(&path, b"{ invalid json }").expect("file write should succeed");

        assert!(load_model(&path).is_err());
    }

    #[test]
    fn test_load_model_invalid_yaml() {
        let (_dir, path) = temp_model_path("yaml");
        std::fs::write(&path, b"this: is: not: valid: yaml: [}")
            .expect("file write should succeed");

        assert!(load_model(&path).is_err());
    }

    #[test]
    fn test_load_yml_extension() {
        let params = vec![("weight".to_string(), Tensor::from_vec(vec![1.0], true))];
        let original = Model::new(ModelMetadata::new("yml-test", "simple"), params);
        let (_dir, path) = temp_model_path("yml");

        let config = SaveConfig::new(ModelFormat::Yaml);
        save_model(&original, &path, &config).expect("save should succeed");

        let loaded = load_model(&path).expect("load should succeed");
        assert_eq!(original.metadata.name, loaded.metadata.name);
    }

    #[test]
    fn test_load_model_safetensors() {
        let params = vec![
            ("weight".to_string(), Tensor::from_vec(vec![1.0, 2.0, 3.0], true)),
            ("bias".to_string(), Tensor::from_vec(vec![0.1], false)),
        ];
        let original = Model::new(ModelMetadata::new("safetensor-test", "linear"), params);
        let (_dir, path) = temp_model_path("safetensors");

        let config = SaveConfig::new(ModelFormat::SafeTensors);
        save_model(&original, &path, &config).expect("save should succeed");

        let loaded = load_model(&path).expect("load should succeed");

        assert_eq!(original.metadata.name, loaded.metadata.name);
        assert_eq!(original.metadata.architecture, loaded.metadata.architecture);
        assert_eq!(original.parameters.len(), loaded.parameters.len());
    }

    #[test]
    fn test_safetensors_round_trip_data_integrity() {
        let params = vec![
            ("layer1.weight".to_string(), Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], true)),
            ("layer1.bias".to_string(), Tensor::from_vec(vec![0.5, 0.6], false)),
        ];
        let original = Model::new(ModelMetadata::new("round-trip", "mlp"), params);
        let (_dir, path) = temp_model_path("safetensors");

        let config = SaveConfig::new(ModelFormat::SafeTensors);
        save_model(&original, &path, &config).expect("save should succeed");

        let loaded = load_model(&path).expect("load should succeed");

        // Verify data matches
        for (name, orig_tensor) in &original.parameters {
            let loaded_tensor =
                loaded.get_parameter(name).expect("parameter should exist after load");
            assert_eq!(orig_tensor.data(), loaded_tensor.data());
        }
    }

    #[test]
    fn test_load_safetensors_file_not_found() {
        let (_dir, path) = temp_model_path("safetensors");
        assert!(load_model(&path).is_err());
    }

    #[test]
    fn test_load_safetensors_invalid_data() {
        let (_dir, path) = temp_model_path("safetensors");
        std::fs::write(&path, b"not valid safetensors binary data")
            .expect("file write should succeed");

        assert!(load_model(&path).is_err());
    }

    #[test]
    fn test_load_safetensors_large_model() {
        let large_data: Vec<f32> = (0..5000).map(|i| i as f32 * 0.001).collect();
        let params = vec![
            ("large_weight".to_string(), Tensor::from_vec(large_data, false)),
            ("small_bias".to_string(), Tensor::from_vec(vec![0.1, 0.2], false)),
        ];
        let original = Model::new(ModelMetadata::new("large-model", "test"), params);
        let (_dir, path) = temp_model_path("safetensors");

        let config = SaveConfig::new(ModelFormat::SafeTensors);
        save_model(&original, &path, &config).expect("save should succeed");

        let loaded = load_model(&path).expect("load should succeed");

        let loaded_large =
            loaded.get_parameter("large_weight").expect("parameter should exist after load");
        assert_eq!(loaded_large.len(), 5000);

        // Verify some values
        let data = loaded_large.data();
        assert!((data[[0]] - 0.0).abs() < 1e-6);
        assert!((data[[4999]] - 4.999).abs() < 1e-3);
    }

    #[test]
    fn test_load_safetensors_metadata_preserved() {
        let params = vec![("w".to_string(), Tensor::from_vec(vec![1.0], false))];
        let original = Model::new(ModelMetadata::new("meta-model", "transformer"), params);
        let (_dir, path) = temp_model_path("safetensors");

        let config = SaveConfig::new(ModelFormat::SafeTensors);
        save_model(&original, &path, &config).expect("save should succeed");

        let loaded = load_model(&path).expect("load should succeed");

        assert_eq!(loaded.metadata.name, "meta-model");
        assert_eq!(loaded.metadata.architecture, "transformer");
    }

    /// Model loading performance: load_bench / loading_time (PW-07)
    #[test]
    fn load_bench_loading_time() {
        let params = vec![("w".to_string(), Tensor::from_vec(vec![1.0; 1000], false))];
        let original = Model::new(ModelMetadata::new("bench-model", "test"), params);
        let (_dir, path) = temp_model_path("safetensors");

        let config = SaveConfig::new(ModelFormat::SafeTensors);
        save_model(&original, &path, &config).expect("save should succeed");

        let start = std::time::Instant::now();
        let _loaded = load_model(&path).expect("load should succeed");
        let loading_time = start.elapsed();

        // Model loading should complete in reasonable time
        assert!(loading_time.as_millis() < 5000, "load_bench: {loading_time:?}");
    }

    #[test]
    fn test_apr_round_trip() {
        let params = vec![
            ("layer1.weight".to_string(), Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], true)),
            ("layer1.bias".to_string(), Tensor::from_vec(vec![0.5, 0.6], false)),
        ];
        let original = Model::new(ModelMetadata::new("apr-test", "transformer"), params);
        let (_dir, path) = temp_model_path("apr");

        let config = SaveConfig::new(ModelFormat::Apr);
        save_model(&original, &path, &config).expect("APR save should succeed");

        let loaded = load_model(&path).expect("APR load should succeed");

        assert_eq!(loaded.metadata.name, "apr-test");
        assert_eq!(loaded.metadata.architecture, "transformer");
        assert_eq!(loaded.parameters.len(), 2);

        for (name, orig_tensor) in &original.parameters {
            let loaded_tensor = loaded.get_parameter(name).expect("tensor should exist");
            assert_eq!(orig_tensor.data(), loaded_tensor.data());
        }
    }
}