Skip to main content

entrenar/io/
save.rs

1//! Model saving functionality
2
3use super::format::{ModelFormat, SaveConfig};
4use super::model::Model;
5use crate::Tensor;
6use crate::{Error, Result};
7use safetensors::tensor::{Dtype, TensorView};
8use std::collections::HashMap;
9use std::fs::File;
10use std::io::Write;
11use std::path::Path;
12
13/// Save a model to a file
14///
15/// # Arguments
16///
17/// * `model` - The model to save
18/// * `path` - Output file path
19/// * `config` - Save configuration (format, options)
20///
21/// # Example
22///
23/// ```no_run
24/// use entrenar::io::{Model, ModelMetadata, save_model, SaveConfig, ModelFormat};
25/// # use entrenar::Tensor;
26///
27/// let params = vec![
28///     ("weight".to_string(), Tensor::from_vec(vec![1.0, 2.0], true)),
29/// ];
30/// let model = Model::new(ModelMetadata::new("my-model", "linear"), params);
31/// let config = SaveConfig::new(ModelFormat::Json);
32///
33/// save_model(&model, "model.json", &config).expect("failed to save model");
34/// ```
35pub fn save_model(model: &Model, path: impl AsRef<Path>, config: &SaveConfig) -> Result<()> {
36    let path = path.as_ref();
37
38    match config.format {
39        ModelFormat::SafeTensors => save_safetensors(model, path),
40        ModelFormat::Apr => save_apr(model, path),
41        ModelFormat::Json => save_json(model, path, config.pretty),
42        ModelFormat::Yaml => save_yaml(model, path),
43        #[cfg(feature = "gguf")]
44        ModelFormat::Gguf => Err(Error::Serialization(
45            "GGUF format not yet implemented. Enable 'gguf' feature and use realizar integration."
46                .to_string(),
47        )),
48    }
49}
50
51/// Serialize and save a model as JSON
52fn save_json(model: &Model, path: &Path, pretty: bool) -> Result<()> {
53    let state = model.to_state();
54    let data = if pretty {
55        serde_json::to_string_pretty(&state)
56            .map_err(|e| Error::Serialization(format!("JSON serialization failed: {e}")))?
57    } else {
58        serde_json::to_string(&state)
59            .map_err(|e| Error::Serialization(format!("JSON serialization failed: {e}")))?
60    };
61    let mut file = File::create(path)?;
62    file.write_all(data.as_bytes())?;
63    Ok(())
64}
65
66/// Serialize and save a model as YAML
67fn save_yaml(model: &Model, path: &Path) -> Result<()> {
68    let state = model.to_state();
69    let data = serde_yaml::to_string(&state)
70        .map_err(|e| Error::Serialization(format!("YAML serialization failed: {e}")))?;
71    let mut file = File::create(path)?;
72    file.write_all(data.as_bytes())?;
73    Ok(())
74}
75
76/// ALB-086: Infer tensor shapes using config-aware batch analysis.
77/// Scans all parameters to find hidden_size from norm weights, then
78/// computes proper 2D shapes for all weight matrices.
79fn infer_all_tensor_shapes(parameters: &[(String, Tensor)]) -> HashMap<String, Vec<usize>> {
80    let mut shapes = HashMap::new();
81
82    // Find hidden_size from a norm weight (always 1D [H])
83    let hidden_size = parameters
84        .iter()
85        .find(|(n, _)| n.ends_with("layernorm.weight") || n == "model.norm.weight")
86        .map_or(0, |(_, t)| t.len());
87
88    for (name, tensor) in parameters {
89        let numel = tensor.len();
90        let shape = if name.ends_with("layernorm.weight") || name == "model.norm.weight" {
91            vec![numel]
92        } else if hidden_size > 0 && numel % hidden_size == 0 {
93            let other_dim = numel / hidden_size;
94            // For down_proj: [hidden_size, intermediate_size] — hidden is smaller dim
95            // For gate/up_proj: [intermediate_size, hidden_size] — hidden is smaller dim
96            // For q/o_proj: [hidden_size, hidden_size] — square
97            // For k/v_proj: [kv_dim, hidden_size] — kv_dim < hidden
98            // For embed/lm_head: [vocab_size, hidden_size]
99            if name.ends_with("down_proj.weight") {
100                vec![hidden_size, other_dim]
101            } else {
102                vec![other_dim, hidden_size]
103            }
104        } else {
105            vec![numel]
106        };
107        shapes.insert(name.clone(), shape);
108    }
109    shapes
110}
111
112/// Save model in SafeTensors format (HuggingFace compatible)
113fn save_safetensors(model: &Model, path: &Path) -> Result<()> {
114    // ALB-086: Compute proper 2D shapes for HuggingFace compatibility
115    let shapes = infer_all_tensor_shapes(&model.parameters);
116
117    // Collect tensor data with proper lifetime management
118    let tensor_data: Vec<(String, Vec<u8>, Vec<usize>)> = model
119        .parameters
120        .iter()
121        .map(|(name, tensor)| {
122            let data = tensor.data();
123            let bytes: Vec<u8> =
124                bytemuck::cast_slice(data.as_slice().expect("tensor data must be contiguous"))
125                    .to_vec();
126            let shape = shapes.get(name).cloned().unwrap_or_else(|| vec![tensor.len()]);
127            (name.clone(), bytes, shape)
128        })
129        .collect();
130
131    // Create TensorViews from collected data
132    let views: Vec<(&str, TensorView<'_>)> = tensor_data
133        .iter()
134        .map(|(name, bytes, shape)| {
135            let view = TensorView::new(Dtype::F32, shape.clone(), bytes)
136                .expect("TensorView construction must not fail for valid F32 data");
137            (name.as_str(), view)
138        })
139        .collect();
140
141    // Create metadata with model info
142    let mut metadata = HashMap::new();
143    metadata.insert("name".to_string(), model.metadata.name.clone());
144    metadata.insert("architecture".to_string(), model.metadata.architecture.clone());
145    metadata.insert("version".to_string(), model.metadata.version.clone());
146
147    // Serialize to SafeTensors format
148    let safetensor_bytes = safetensors::serialize(views, Some(metadata))
149        .map_err(|e| Error::Serialization(format!("SafeTensors serialization failed: {e}")))?;
150
151    // Write to file
152    std::fs::write(path, safetensor_bytes)?;
153
154    Ok(())
155}
156
157/// ALB-096: Save model in APR format (sovereign stack universal format).
158///
159/// Uses `AprWriter` for atomic single-file checkpoints with proper 2D shapes
160/// and model metadata. Supports both model weights and `__training__.*` tensors.
161fn save_apr(model: &Model, path: &Path) -> Result<()> {
162    use aprender::serialization::apr::AprWriter;
163    use serde_json::Value as JsonValue;
164
165    let mut writer = AprWriter::new();
166
167    // Embed model metadata (keys must match AprWriter's well-known key mapping)
168    writer.set_metadata("model_name", JsonValue::String(model.metadata.name.clone()));
169    writer.set_metadata("architecture", JsonValue::String(model.metadata.architecture.clone()));
170    writer.set_metadata("version", JsonValue::String(model.metadata.version.clone()));
171    writer.set_metadata("format", JsonValue::String("entrenar-checkpoint".into()));
172
173    // Compute proper 2D shapes (same logic as SafeTensors — ALB-086)
174    let shapes = infer_all_tensor_shapes(&model.parameters);
175
176    for (name, tensor) in &model.parameters {
177        let data = tensor.data();
178        let slice = data.as_slice().expect("tensor data must be contiguous");
179        let shape = shapes.get(name).cloned().unwrap_or_else(|| vec![tensor.len()]);
180        writer.add_tensor_f32(name, shape, slice);
181    }
182
183    writer
184        .write(path)
185        .map_err(|e| Error::Serialization(format!("APR serialization failed: {e}")))?;
186
187    Ok(())
188}
189
#[cfg(test)]
mod tests {
    use super::*;
    use crate::io::{Model, ModelMetadata};
    use crate::Tensor;
    use tempfile::NamedTempFile;

    /// Saves `model` with `config` into a fresh temporary file and returns the
    /// file handle; the file lives for as long as the handle does.
    fn save_to_temp(model: &Model, config: &SaveConfig) -> NamedTempFile {
        let file = NamedTempFile::new().expect("temp file creation should succeed");
        save_model(model, file.path(), config).expect("save should succeed");
        file
    }

    /// Reads the saved file back as UTF-8 text.
    fn read_text(file: &NamedTempFile) -> String {
        std::fs::read_to_string(file.path()).expect("file read should succeed")
    }

    /// Reads the saved file back as raw bytes.
    fn read_bytes(file: &NamedTempFile) -> Vec<u8> {
        std::fs::read(file.path()).expect("file read should succeed")
    }

    /// Builds a model holding one single-element tensor named `w`.
    fn single_weight_model(name: &str, architecture: &str) -> Model {
        let params = vec![("w".to_string(), Tensor::from_vec(vec![1.0], false))];
        Model::new(ModelMetadata::new(name, architecture), params)
    }

    #[test]
    fn test_save_model_json() {
        let model = Model::new(
            ModelMetadata::new("test-model", "linear"),
            vec![
                ("weight".to_string(), Tensor::from_vec(vec![1.0, 2.0, 3.0], true)),
                ("bias".to_string(), Tensor::from_vec(vec![0.1], false)),
            ],
        );

        let file = save_to_temp(&model, &SaveConfig::new(ModelFormat::Json));

        // The file must exist, be non-empty and mention the metadata values.
        let text = read_text(&file);
        assert!(!text.is_empty());
        assert!(text.contains("test-model"));
        assert!(text.contains("linear"));
    }

    #[test]
    fn test_save_model_yaml() {
        let model = Model::new(
            ModelMetadata::new("test", "simple"),
            vec![("weight".to_string(), Tensor::from_vec(vec![1.0, 2.0], true))],
        );

        let file = save_to_temp(&model, &SaveConfig::new(ModelFormat::Yaml));

        let text = read_text(&file);
        assert!(text.contains("test"));
        assert!(text.contains("simple"));
    }

    #[test]
    fn test_save_model_json_pretty() {
        let model = single_weight_model("pretty-test", "test");
        let config = SaveConfig::new(ModelFormat::Json).with_pretty(true);

        let file = save_to_temp(&model, &config);

        // Pretty JSON should have newlines
        assert!(read_text(&file).contains('\n'));
    }

    #[test]
    fn test_save_model_json_compact() {
        let model = single_weight_model("compact-test", "test");
        let config = SaveConfig::new(ModelFormat::Json).with_pretty(false);

        let file = save_to_temp(&model, &config);

        // Compact JSON should be single line (minus trailing)
        assert_eq!(read_text(&file).lines().count(), 1);
    }

    #[test]
    fn test_save_model_empty_params() {
        let model = Model::new(ModelMetadata::new("empty", "test"), Vec::new());

        let file = save_to_temp(&model, &SaveConfig::new(ModelFormat::Json));

        assert!(read_text(&file).contains("empty"));
    }

    #[test]
    fn test_save_model_large_tensor() {
        let values: Vec<f32> = (0..1000).map(|i| i as f32 * 0.001).collect();
        let model = Model::new(
            ModelMetadata::new("large", "test"),
            vec![("large".to_string(), Tensor::from_vec(values, false))],
        );

        let file = save_to_temp(&model, &SaveConfig::new(ModelFormat::Json));

        assert!(read_text(&file).len() > 1000);
    }

    #[test]
    fn test_save_config_builder() {
        let config = SaveConfig::new(ModelFormat::Json).with_pretty(true);
        assert_eq!(config.format, ModelFormat::Json);
        assert!(config.pretty);
    }

    #[test]
    fn test_save_model_with_compress_option() {
        let model = single_weight_model("compress-test", "test");
        let config = SaveConfig::new(ModelFormat::Json).with_compress(true);

        // Currently compress is not implemented, but we can still save
        let file = save_to_temp(&model, &config);

        assert!(read_text(&file).contains("compress-test"));
    }

    #[test]
    fn test_save_model_multiple_tensors() {
        let model = Model::new(
            ModelMetadata::new("multi", "deep"),
            vec![
                ("layer1.weight".to_string(), Tensor::from_vec(vec![1.0, 2.0], true)),
                ("layer1.bias".to_string(), Tensor::from_vec(vec![0.1], true)),
                ("layer2.weight".to_string(), Tensor::from_vec(vec![3.0, 4.0], false)),
            ],
        );

        let file = save_to_temp(&model, &SaveConfig::new(ModelFormat::Yaml));

        let text = read_text(&file);
        assert!(text.contains("layer1.weight"));
        assert!(text.contains("layer2.weight"));
    }

    #[test]
    fn test_save_model_with_metadata() {
        let meta = ModelMetadata::new("meta-test", "test")
            .with_custom("version", serde_json::json!("1.0.0"))
            .with_custom("author", serde_json::json!("test"));
        let model =
            Model::new(meta, vec![("w".to_string(), Tensor::from_vec(vec![1.0], false))]);

        let file = save_to_temp(&model, &SaveConfig::new(ModelFormat::Json));

        assert!(read_text(&file).contains("version"));
    }

    #[test]
    fn test_save_config_default() {
        let config = SaveConfig::default();
        assert_eq!(config.format, ModelFormat::Json);
        assert!(config.pretty);
        assert!(!config.compress);
    }

    #[test]
    fn test_save_model_invalid_path() {
        let model = single_weight_model("test", "test");
        let config = SaveConfig::new(ModelFormat::Json);

        // Try to save to an invalid directory
        assert!(save_model(&model, "/nonexistent/directory/model.json", &config).is_err());
    }

    #[test]
    fn test_save_model_safetensors() {
        let model = Model::new(
            ModelMetadata::new("safetensor-test", "linear"),
            vec![
                ("weight".to_string(), Tensor::from_vec(vec![1.0, 2.0, 3.0], true)),
                ("bias".to_string(), Tensor::from_vec(vec![0.1], false)),
            ],
        );

        let file = save_to_temp(&model, &SaveConfig::new(ModelFormat::SafeTensors));

        // Verify file was created and is binary
        let bytes = read_bytes(&file);
        assert!(!bytes.is_empty());
        // SafeTensors files start with a header length (8 bytes)
        assert!(bytes.len() > 8);
    }

    #[test]
    fn test_save_model_safetensors_can_be_loaded() {
        let model = Model::new(
            ModelMetadata::new("roundtrip-test", "mlp"),
            vec![
                ("layer1.weight".to_string(), Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], true)),
                ("layer1.bias".to_string(), Tensor::from_vec(vec![0.5], false)),
            ],
        );

        let file = save_to_temp(&model, &SaveConfig::new(ModelFormat::SafeTensors));

        // Verify we can load it back with safetensors crate
        let bytes = read_bytes(&file);
        let loaded = safetensors::SafeTensors::deserialize(&bytes).expect("load should succeed");

        // Check tensor names exist - names() returns Vec<&str>
        let names = loaded.names();
        assert!(names.contains(&"layer1.weight"));
        assert!(names.contains(&"layer1.bias"));

        // Check tensor data
        let weight = loaded.tensor("layer1.weight").expect("load should succeed");
        assert_eq!(weight.shape(), &[4]);
        let weight_data: &[f32] = bytemuck::cast_slice(weight.data());
        assert_eq!(weight_data, &[1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_save_safetensors_metadata() {
        let model = single_weight_model("meta-model", "transformer");

        let file = save_to_temp(&model, &SaveConfig::new(ModelFormat::SafeTensors));

        // Load and check metadata using read_metadata
        let bytes = read_bytes(&file);
        let (_, header) = safetensors::SafeTensors::read_metadata(&bytes)
            .expect("deserialization should succeed");

        let metadata = header.metadata();
        assert!(metadata.is_some());
        let entries = metadata.as_ref().expect("operation should succeed");
        assert_eq!(entries.get("name").expect("key should exist"), "meta-model");
        assert_eq!(entries.get("architecture").expect("key should exist"), "transformer");
    }

    #[test]
    fn test_save_safetensors_large_tensor() {
        let values: Vec<f32> = (0..10000).map(|i| i as f32 * 0.001).collect();
        let model = Model::new(
            ModelMetadata::new("large", "test"),
            vec![("large_weights".to_string(), Tensor::from_vec(values, false))],
        );

        let file = save_to_temp(&model, &SaveConfig::new(ModelFormat::SafeTensors));

        // Verify data integrity
        let bytes = read_bytes(&file);
        let loaded = safetensors::SafeTensors::deserialize(&bytes).expect("load should succeed");
        let tensor = loaded.tensor("large_weights").expect("load should succeed");
        let restored: &[f32] = bytemuck::cast_slice(tensor.data());
        assert_eq!(restored.len(), 10000);
        assert!((restored[0] - 0.0).abs() < 1e-6);
        assert!((restored[9999] - 9.999).abs() < 1e-3);
    }

    #[test]
    fn test_save_safetensors_invalid_path() {
        let model = single_weight_model("test", "test");
        let config = SaveConfig::new(ModelFormat::SafeTensors);

        assert!(
            save_model(&model, "/nonexistent/directory/model.safetensors", &config).is_err()
        );
    }

    #[test]
    fn test_save_safetensors_empty_params() {
        let model = Model::new(ModelMetadata::new("empty", "test"), Vec::new());

        let file = save_to_temp(&model, &SaveConfig::new(ModelFormat::SafeTensors));

        // Should still create valid file with metadata
        let bytes = read_bytes(&file);
        let loaded = safetensors::SafeTensors::deserialize(&bytes).expect("load should succeed");
        assert_eq!(loaded.len(), 0);
    }

    #[test]
    fn test_save_safetensors_multiple_tensors() {
        let model = Model::new(
            ModelMetadata::new("encoder-decoder", "transformer"),
            vec![
                ("encoder.layer1.weight".to_string(), Tensor::from_vec(vec![1.0, 2.0], true)),
                ("encoder.layer1.bias".to_string(), Tensor::from_vec(vec![0.1], true)),
                (
                    "encoder.layer2.weight".to_string(),
                    Tensor::from_vec(vec![3.0, 4.0, 5.0], false),
                ),
                ("decoder.layer1.weight".to_string(), Tensor::from_vec(vec![6.0, 7.0], false)),
            ],
        );

        let file = save_to_temp(&model, &SaveConfig::new(ModelFormat::SafeTensors));

        let bytes = read_bytes(&file);
        let loaded = safetensors::SafeTensors::deserialize(&bytes).expect("load should succeed");
        assert_eq!(loaded.len(), 4);

        // names() returns Vec<&str> directly
        let names = loaded.names();
        assert!(names.contains(&"encoder.layer1.weight"));
        assert!(names.contains(&"decoder.layer1.weight"));
    }

    /// ALB-086: Verify SafeTensors saves proper 2D shapes for LlamaForCausalLM weights.
    #[test]
    fn test_safetensors_saves_2d_shapes() {
        let hidden: usize = 64;
        let intermediate: usize = 128;
        let vocab: usize = 256;
        let kv_dim: usize = 16;

        // Flat zero-filled parameter with the given name and element count.
        let zeros = |name: &str, numel: usize| (name.to_string(), Tensor::zeros(numel, false));

        let params = vec![
            zeros("model.embed_tokens.weight", vocab * hidden),
            zeros("model.norm.weight", hidden),
            zeros("model.layers.0.input_layernorm.weight", hidden),
            zeros("model.layers.0.post_attention_layernorm.weight", hidden),
            zeros("model.layers.0.self_attn.q_proj.weight", hidden * hidden),
            zeros("model.layers.0.self_attn.k_proj.weight", kv_dim * hidden),
            zeros("model.layers.0.self_attn.v_proj.weight", kv_dim * hidden),
            zeros("model.layers.0.self_attn.o_proj.weight", hidden * hidden),
            zeros("model.layers.0.mlp.gate_proj.weight", intermediate * hidden),
            zeros("model.layers.0.mlp.up_proj.weight", intermediate * hidden),
            zeros("model.layers.0.mlp.down_proj.weight", hidden * intermediate),
        ];

        let model = Model::new(ModelMetadata::new("test", "LlamaForCausalLM"), params);
        let file = save_to_temp(&model, &SaveConfig::new(ModelFormat::SafeTensors));

        let bytes = read_bytes(&file);
        let loaded = safetensors::SafeTensors::deserialize(&bytes).expect("load should succeed");
        let shape_of =
            |name: &str| loaded.tensor(name).expect("tensor should exist").shape().to_vec();

        // Norm weights should be 1D
        assert_eq!(shape_of("model.norm.weight"), vec![hidden]);
        assert_eq!(shape_of("model.layers.0.input_layernorm.weight"), vec![hidden]);

        // Projection weights should be 2D
        assert_eq!(shape_of("model.embed_tokens.weight"), vec![vocab, hidden]);
        assert_eq!(shape_of("model.layers.0.self_attn.q_proj.weight"), vec![hidden, hidden]);
        assert_eq!(shape_of("model.layers.0.self_attn.k_proj.weight"), vec![kv_dim, hidden]);
        assert_eq!(shape_of("model.layers.0.mlp.gate_proj.weight"), vec![intermediate, hidden]);
        assert_eq!(shape_of("model.layers.0.mlp.down_proj.weight"), vec![hidden, intermediate]);
    }
}