Skip to main content

entrenar/io/
save.rs

1//! Model saving functionality
2
3use super::format::{ModelFormat, SaveConfig};
4use super::model::Model;
5use crate::Tensor;
6use crate::{Error, Result};
7use safetensors::tensor::{Dtype, TensorView};
8use std::collections::HashMap;
9use std::fs::File;
10use std::io::Write;
11use std::path::Path;
12
13/// Save a model to a file
14///
15/// # Arguments
16///
17/// * `model` - The model to save
18/// * `path` - Output file path
19/// * `config` - Save configuration (format, options)
20///
21/// # Example
22///
23/// ```no_run
24/// use entrenar::io::{Model, ModelMetadata, save_model, SaveConfig, ModelFormat};
25/// # use entrenar::Tensor;
26///
27/// let params = vec![
28///     ("weight".to_string(), Tensor::from_vec(vec![1.0, 2.0], true)),
29/// ];
30/// let model = Model::new(ModelMetadata::new("my-model", "linear"), params);
31/// let config = SaveConfig::new(ModelFormat::Json);
32///
33/// save_model(&model, "model.json", &config).expect("failed to save model");
34/// ```
35pub fn save_model(model: &Model, path: impl AsRef<Path>, config: &SaveConfig) -> Result<()> {
36    let path = path.as_ref();
37
38    match config.format {
39        ModelFormat::SafeTensors => save_safetensors(model, path),
40        ModelFormat::Apr => save_apr(model, path),
41        ModelFormat::Json => save_json(model, path, config.pretty),
42        ModelFormat::Yaml => save_yaml(model, path),
43        #[cfg(feature = "gguf")]
44        ModelFormat::Gguf => Err(Error::Serialization(
45            "GGUF format not yet implemented. Enable 'gguf' feature and use realizar integration."
46                .to_string(),
47        )),
48    }
49}
50
51/// Serialize and save a model as JSON
52fn save_json(model: &Model, path: &Path, pretty: bool) -> Result<()> {
53    let state = model.to_state();
54    let data = if pretty {
55        serde_json::to_string_pretty(&state)
56            .map_err(|e| Error::Serialization(format!("JSON serialization failed: {e}")))?
57    } else {
58        serde_json::to_string(&state)
59            .map_err(|e| Error::Serialization(format!("JSON serialization failed: {e}")))?
60    };
61    let mut file = File::create(path)?;
62    file.write_all(data.as_bytes())?;
63    Ok(())
64}
65
66/// Serialize and save a model as YAML
67fn save_yaml(model: &Model, path: &Path) -> Result<()> {
68    let state = model.to_state();
69    let data = serde_yaml::to_string(&state)
70        .map_err(|e| Error::Serialization(format!("YAML serialization failed: {e}")))?;
71    let mut file = File::create(path)?;
72    file.write_all(data.as_bytes())?;
73    Ok(())
74}
75
76/// ALB-086: Infer tensor shapes using config-aware batch analysis.
77/// Scans all parameters to find hidden_size from norm weights, then
78/// computes proper 2D shapes for all weight matrices.
79pub(crate) fn infer_all_tensor_shapes(
80    parameters: &[(String, Tensor)],
81) -> HashMap<String, Vec<usize>> {
82    let mut shapes = HashMap::new();
83
84    // Find hidden_size from a norm weight (always 1D [H])
85    let hidden_size = parameters
86        .iter()
87        .find(|(n, _)| n.ends_with("layernorm.weight") || n == "model.norm.weight")
88        .map_or(0, |(_, t)| t.len());
89
90    for (name, tensor) in parameters {
91        let numel = tensor.len();
92        let shape = if name.ends_with("layernorm.weight") || name == "model.norm.weight" {
93            vec![numel]
94        } else if hidden_size > 0 && numel % hidden_size == 0 {
95            let other_dim = numel / hidden_size;
96            // For down_proj: [hidden_size, intermediate_size] — hidden is smaller dim
97            // For gate/up_proj: [intermediate_size, hidden_size] — hidden is smaller dim
98            // For q/o_proj: [hidden_size, hidden_size] — square
99            // For k/v_proj: [kv_dim, hidden_size] — kv_dim < hidden
100            // For embed/lm_head: [vocab_size, hidden_size]
101            if name.ends_with("down_proj.weight") {
102                vec![hidden_size, other_dim]
103            } else {
104                vec![other_dim, hidden_size]
105            }
106        } else {
107            vec![numel]
108        };
109        shapes.insert(name.clone(), shape);
110    }
111    shapes
112}
113
114/// Save model in SafeTensors format (HuggingFace compatible)
115fn save_safetensors(model: &Model, path: &Path) -> Result<()> {
116    // ALB-086: Compute proper 2D shapes for HuggingFace compatibility
117    let shapes = infer_all_tensor_shapes(&model.parameters);
118
119    // Collect tensor data with proper lifetime management
120    let tensor_data: Vec<(String, Vec<u8>, Vec<usize>)> = model
121        .parameters
122        .iter()
123        .map(|(name, tensor)| {
124            let data = tensor.data();
125            let bytes: Vec<u8> =
126                bytemuck::cast_slice(data.as_slice().expect("tensor data must be contiguous"))
127                    .to_vec();
128            let shape = shapes.get(name).cloned().unwrap_or_else(|| vec![tensor.len()]);
129            (name.clone(), bytes, shape)
130        })
131        .collect();
132
133    // Create TensorViews from collected data
134    let views: Vec<(&str, TensorView<'_>)> = tensor_data
135        .iter()
136        .map(|(name, bytes, shape)| {
137            let view = TensorView::new(Dtype::F32, shape.clone(), bytes)
138                .expect("TensorView construction must not fail for valid F32 data");
139            (name.as_str(), view)
140        })
141        .collect();
142
143    // Create metadata with model info
144    let mut metadata = HashMap::new();
145    metadata.insert("name".to_string(), model.metadata.name.clone());
146    metadata.insert("architecture".to_string(), model.metadata.architecture.clone());
147    metadata.insert("version".to_string(), model.metadata.version.clone());
148
149    // Serialize to SafeTensors format
150    let safetensor_bytes = safetensors::serialize(views, Some(metadata))
151        .map_err(|e| Error::Serialization(format!("SafeTensors serialization failed: {e}")))?;
152
153    // Write to file
154    std::fs::write(path, safetensor_bytes)?;
155
156    Ok(())
157}
158
159/// ALB-096: Save model in APR format (sovereign stack universal format).
160///
161/// Uses `AprWriter` for atomic single-file checkpoints with proper 2D shapes
162/// and model metadata. Supports both model weights and `__training__.*` tensors.
163fn save_apr(model: &Model, path: &Path) -> Result<()> {
164    use aprender::serialization::apr::AprWriter;
165    use serde_json::Value as JsonValue;
166
167    let mut writer = AprWriter::new();
168
169    // Embed model metadata (keys must match AprWriter's well-known key mapping)
170    writer.set_metadata("model_name", JsonValue::String(model.metadata.name.clone()));
171    writer.set_metadata("architecture", JsonValue::String(model.metadata.architecture.clone()));
172    writer.set_metadata("version", JsonValue::String(model.metadata.version.clone()));
173    writer.set_metadata("format", JsonValue::String("entrenar-checkpoint".into()));
174
175    // Compute proper 2D shapes (same logic as SafeTensors — ALB-086)
176    let shapes = infer_all_tensor_shapes(&model.parameters);
177
178    for (name, tensor) in &model.parameters {
179        let data = tensor.data();
180        let slice = data.as_slice().expect("tensor data must be contiguous");
181        let shape = shapes.get(name).cloned().unwrap_or_else(|| vec![tensor.len()]);
182        writer.add_tensor_f32(name, shape, slice);
183    }
184
185    writer
186        .write(path)
187        .map_err(|e| Error::Serialization(format!("APR serialization failed: {e}")))?;
188
189    Ok(())
190}
191
#[cfg(test)]
mod tests {
    use super::*;
    use crate::io::{Model, ModelMetadata};
    use crate::Tensor;
    use tempfile::NamedTempFile;

    /// Saves `model` with `config` into a fresh temporary file.
    ///
    /// Keep the returned handle alive while reading: dropping it deletes the file.
    fn save_to_temp(model: &Model, config: &SaveConfig) -> NamedTempFile {
        let temp_file = NamedTempFile::new().expect("temp file creation should succeed");
        save_model(model, temp_file.path(), config).expect("save should succeed");
        temp_file
    }

    /// Reads a saved text (JSON/YAML) file back as a string.
    fn read_text(file: &NamedTempFile) -> String {
        std::fs::read_to_string(file.path()).expect("file read should succeed")
    }

    /// Reads a saved binary (SafeTensors) file back as raw bytes.
    fn read_bytes(file: &NamedTempFile) -> Vec<u8> {
        std::fs::read(file.path()).expect("file read should succeed")
    }

    /// Builds a model with a single one-element parameter named `w`.
    fn tiny_model(name: &str, architecture: &str) -> Model {
        let params = vec![("w".to_string(), Tensor::from_vec(vec![1.0], false))];
        Model::new(ModelMetadata::new(name, architecture), params)
    }

    #[test]
    fn test_save_model_json() {
        let params = vec![
            ("weight".to_string(), Tensor::from_vec(vec![1.0, 2.0, 3.0], true)),
            ("bias".to_string(), Tensor::from_vec(vec![0.1], false)),
        ];
        let model = Model::new(ModelMetadata::new("test-model", "linear"), params);

        let file = save_to_temp(&model, &SaveConfig::new(ModelFormat::Json));

        let content = read_text(&file);
        assert!(!content.is_empty());
        assert!(content.contains("test-model"));
        assert!(content.contains("linear"));
    }

    #[test]
    fn test_save_model_yaml() {
        let params = vec![("weight".to_string(), Tensor::from_vec(vec![1.0, 2.0], true))];
        let model = Model::new(ModelMetadata::new("test", "simple"), params);

        let file = save_to_temp(&model, &SaveConfig::new(ModelFormat::Yaml));

        let content = read_text(&file);
        assert!(content.contains("test"));
        assert!(content.contains("simple"));
    }

    #[test]
    fn test_save_model_json_pretty() {
        let model = tiny_model("pretty-test", "test");
        let config = SaveConfig::new(ModelFormat::Json).with_pretty(true);

        let file = save_to_temp(&model, &config);

        // Pretty JSON should span multiple lines
        assert!(read_text(&file).contains('\n'));
    }

    #[test]
    fn test_save_model_json_compact() {
        let model = tiny_model("compact-test", "test");
        let config = SaveConfig::new(ModelFormat::Json).with_pretty(false);

        let file = save_to_temp(&model, &config);

        // Compact JSON should be a single line
        assert_eq!(read_text(&file).lines().count(), 1);
    }

    #[test]
    fn test_save_model_empty_params() {
        let model = Model::new(ModelMetadata::new("empty", "test"), vec![]);

        let file = save_to_temp(&model, &SaveConfig::new(ModelFormat::Json));

        assert!(read_text(&file).contains("empty"));
    }

    #[test]
    fn test_save_model_large_tensor() {
        let large_data: Vec<f32> = (0..1000).map(|i| i as f32 * 0.001).collect();
        let params = vec![("large".to_string(), Tensor::from_vec(large_data, false))];
        let model = Model::new(ModelMetadata::new("large", "test"), params);

        let file = save_to_temp(&model, &SaveConfig::new(ModelFormat::Json));

        assert!(read_text(&file).len() > 1000);
    }

    #[test]
    fn test_save_config_builder() {
        let config = SaveConfig::new(ModelFormat::Json).with_pretty(true);
        assert!(config.pretty);
        assert_eq!(config.format, ModelFormat::Json);
    }

    #[test]
    fn test_save_model_with_compress_option() {
        let model = tiny_model("compress-test", "test");
        let config = SaveConfig::new(ModelFormat::Json).with_compress(true);

        // Compression is not implemented yet, but saving must still succeed
        let file = save_to_temp(&model, &config);

        assert!(read_text(&file).contains("compress-test"));
    }

    #[test]
    fn test_save_model_multiple_tensors() {
        let params = vec![
            ("layer1.weight".to_string(), Tensor::from_vec(vec![1.0, 2.0], true)),
            ("layer1.bias".to_string(), Tensor::from_vec(vec![0.1], true)),
            ("layer2.weight".to_string(), Tensor::from_vec(vec![3.0, 4.0], false)),
        ];
        let model = Model::new(ModelMetadata::new("multi", "deep"), params);

        let file = save_to_temp(&model, &SaveConfig::new(ModelFormat::Yaml));

        let content = read_text(&file);
        assert!(content.contains("layer1.weight"));
        assert!(content.contains("layer2.weight"));
    }

    #[test]
    fn test_save_model_with_metadata() {
        let params = vec![("w".to_string(), Tensor::from_vec(vec![1.0], false))];
        let meta = ModelMetadata::new("meta-test", "test")
            .with_custom("version", serde_json::json!("1.0.0"))
            .with_custom("author", serde_json::json!("test"));
        let model = Model::new(meta, params);

        let file = save_to_temp(&model, &SaveConfig::new(ModelFormat::Json));

        let content = read_text(&file);
        assert!(content.contains("version"));
        // "version" is also a built-in metadata field, so check a key that can only
        // come from the custom metadata.
        assert!(content.contains("author"));
    }

    #[test]
    fn test_save_config_default() {
        let config = SaveConfig::default();
        assert_eq!(config.format, ModelFormat::Json);
        assert!(config.pretty);
        assert!(!config.compress);
    }

    #[test]
    fn test_save_model_invalid_path() {
        let model = tiny_model("test", "test");
        let config = SaveConfig::new(ModelFormat::Json);

        let result = save_model(&model, "/nonexistent/directory/model.json", &config);
        assert!(result.is_err());
    }

    #[test]
    fn test_save_model_safetensors() {
        let params = vec![
            ("weight".to_string(), Tensor::from_vec(vec![1.0, 2.0, 3.0], true)),
            ("bias".to_string(), Tensor::from_vec(vec![0.1], false)),
        ];
        let model = Model::new(ModelMetadata::new("safetensor-test", "linear"), params);

        let file = save_to_temp(&model, &SaveConfig::new(ModelFormat::SafeTensors));

        let content = read_bytes(&file);
        assert!(!content.is_empty());
        // SafeTensors files start with an 8-byte header length
        assert!(content.len() > 8);
    }

    #[test]
    fn test_save_model_safetensors_can_be_loaded() {
        let params = vec![
            ("layer1.weight".to_string(), Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], true)),
            ("layer1.bias".to_string(), Tensor::from_vec(vec![0.5], false)),
        ];
        let model = Model::new(ModelMetadata::new("roundtrip-test", "mlp"), params);

        let file = save_to_temp(&model, &SaveConfig::new(ModelFormat::SafeTensors));

        let data = read_bytes(&file);
        let loaded = safetensors::SafeTensors::deserialize(&data).expect("load should succeed");

        let names = loaded.names();
        assert!(names.contains(&"layer1.weight"));
        assert!(names.contains(&"layer1.bias"));

        let weight = loaded.tensor("layer1.weight").expect("load should succeed");
        assert_eq!(weight.shape(), &[4]);
        let weight_data: &[f32] = bytemuck::cast_slice(weight.data());
        assert_eq!(weight_data, &[1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_save_safetensors_metadata() {
        let model = tiny_model("meta-model", "transformer");

        let file = save_to_temp(&model, &SaveConfig::new(ModelFormat::SafeTensors));

        let data = read_bytes(&file);
        let (_, st_metadata) =
            safetensors::SafeTensors::read_metadata(&data).expect("deserialization should succeed");

        let meta = st_metadata.metadata().as_ref().expect("metadata should be present");
        assert_eq!(meta.get("name").expect("key should exist"), "meta-model");
        assert_eq!(meta.get("architecture").expect("key should exist"), "transformer");
    }

    #[test]
    fn test_save_safetensors_large_tensor() {
        let large_data: Vec<f32> = (0..10000).map(|i| i as f32 * 0.001).collect();
        let params = vec![("large_weights".to_string(), Tensor::from_vec(large_data, false))];
        let model = Model::new(ModelMetadata::new("large", "test"), params);

        let file = save_to_temp(&model, &SaveConfig::new(ModelFormat::SafeTensors));

        let data = read_bytes(&file);
        let loaded = safetensors::SafeTensors::deserialize(&data).expect("load should succeed");
        let tensor = loaded.tensor("large_weights").expect("load should succeed");
        let tensor_data: &[f32] = bytemuck::cast_slice(tensor.data());
        assert_eq!(tensor_data.len(), 10000);
        assert!((tensor_data[0] - 0.0).abs() < 1e-6);
        assert!((tensor_data[9999] - 9.999).abs() < 1e-3);
    }

    #[test]
    fn test_save_safetensors_invalid_path() {
        let model = tiny_model("test", "test");
        let config = SaveConfig::new(ModelFormat::SafeTensors);

        let result = save_model(&model, "/nonexistent/directory/model.safetensors", &config);
        assert!(result.is_err());
    }

    #[test]
    fn test_save_safetensors_empty_params() {
        let model = Model::new(ModelMetadata::new("empty", "test"), vec![]);

        let file = save_to_temp(&model, &SaveConfig::new(ModelFormat::SafeTensors));

        // Should still create a valid file carrying only metadata
        let data = read_bytes(&file);
        let loaded = safetensors::SafeTensors::deserialize(&data).expect("load should succeed");
        assert_eq!(loaded.len(), 0);
    }

    #[test]
    fn test_save_safetensors_multiple_tensors() {
        let params = vec![
            ("encoder.layer1.weight".to_string(), Tensor::from_vec(vec![1.0, 2.0], true)),
            ("encoder.layer1.bias".to_string(), Tensor::from_vec(vec![0.1], true)),
            ("encoder.layer2.weight".to_string(), Tensor::from_vec(vec![3.0, 4.0, 5.0], false)),
            ("decoder.layer1.weight".to_string(), Tensor::from_vec(vec![6.0, 7.0], false)),
        ];
        let model = Model::new(ModelMetadata::new("encoder-decoder", "transformer"), params);

        let file = save_to_temp(&model, &SaveConfig::new(ModelFormat::SafeTensors));

        let data = read_bytes(&file);
        let loaded = safetensors::SafeTensors::deserialize(&data).expect("load should succeed");
        assert_eq!(loaded.len(), 4);

        let names = loaded.names();
        assert!(names.contains(&"encoder.layer1.weight"));
        assert!(names.contains(&"decoder.layer1.weight"));
    }

    /// ALB-096: the APR writer produces a non-empty file.
    #[test]
    fn test_save_model_apr() {
        let params = vec![("w".to_string(), Tensor::from_vec(vec![1.0, 2.0], false))];
        let model = Model::new(ModelMetadata::new("apr-test", "test"), params);
        // Use a fresh directory so the writer creates the file itself.
        let dir = tempfile::tempdir().expect("temp dir creation should succeed");
        let path = dir.path().join("model.apr");

        save_model(&model, &path, &SaveConfig::new(ModelFormat::Apr)).expect("save should succeed");

        let bytes = std::fs::read(&path).expect("file read should succeed");
        assert!(!bytes.is_empty());
    }

    #[test]
    fn test_save_apr_invalid_path() {
        let model = tiny_model("test", "test");
        let config = SaveConfig::new(ModelFormat::Apr);

        let result = save_model(&model, "/nonexistent/directory/model.apr", &config);
        assert!(result.is_err());
    }

    /// ALB-086: Verify SafeTensors saves proper 2D shapes for LlamaForCausalLM weights.
    #[test]
    fn test_safetensors_saves_2d_shapes() {
        let hidden = 64;
        let intermediate = 128;
        let vocab = 256;
        let kv_dim = 16;

        let zeros = |name: &str, numel: usize| (name.to_string(), Tensor::zeros(numel, false));
        let params = vec![
            zeros("model.embed_tokens.weight", vocab * hidden),
            zeros("model.norm.weight", hidden),
            zeros("model.layers.0.input_layernorm.weight", hidden),
            zeros("model.layers.0.post_attention_layernorm.weight", hidden),
            zeros("model.layers.0.self_attn.q_proj.weight", hidden * hidden),
            zeros("model.layers.0.self_attn.k_proj.weight", kv_dim * hidden),
            zeros("model.layers.0.self_attn.v_proj.weight", kv_dim * hidden),
            zeros("model.layers.0.self_attn.o_proj.weight", hidden * hidden),
            zeros("model.layers.0.mlp.gate_proj.weight", intermediate * hidden),
            zeros("model.layers.0.mlp.up_proj.weight", intermediate * hidden),
            zeros("model.layers.0.mlp.down_proj.weight", hidden * intermediate),
        ];
        let model = Model::new(ModelMetadata::new("test", "LlamaForCausalLM"), params);

        let file = save_to_temp(&model, &SaveConfig::new(ModelFormat::SafeTensors));

        let data = read_bytes(&file);
        let loaded = safetensors::SafeTensors::deserialize(&data).expect("load should succeed");
        let shape_of = |name: &str| {
            loaded.tensor(name).expect("tensor should exist").shape().to_vec()
        };

        // Norm weights should be 1D
        assert_eq!(shape_of("model.norm.weight"), vec![hidden]);
        assert_eq!(shape_of("model.layers.0.input_layernorm.weight"), vec![hidden]);

        // Projection weights should be 2D
        assert_eq!(shape_of("model.embed_tokens.weight"), vec![vocab, hidden]);
        assert_eq!(shape_of("model.layers.0.self_attn.q_proj.weight"), vec![hidden, hidden]);
        assert_eq!(shape_of("model.layers.0.self_attn.k_proj.weight"), vec![kv_dim, hidden]);
        assert_eq!(
            shape_of("model.layers.0.mlp.gate_proj.weight"),
            vec![intermediate, hidden]
        );
        assert_eq!(
            shape_of("model.layers.0.mlp.down_proj.weight"),
            vec![hidden, intermediate]
        );
    }
}