1use super::format::{ModelFormat, SaveConfig};
4use super::model::Model;
5use crate::Tensor;
6use crate::{Error, Result};
7use safetensors::tensor::{Dtype, TensorView};
8use std::collections::HashMap;
9use std::fs::File;
10use std::io::Write;
11use std::path::Path;
12
13pub fn save_model(model: &Model, path: impl AsRef<Path>, config: &SaveConfig) -> Result<()> {
36 let path = path.as_ref();
37
38 match config.format {
39 ModelFormat::SafeTensors => save_safetensors(model, path),
40 ModelFormat::Apr => save_apr(model, path),
41 ModelFormat::Json => save_json(model, path, config.pretty),
42 ModelFormat::Yaml => save_yaml(model, path),
43 #[cfg(feature = "gguf")]
44 ModelFormat::Gguf => Err(Error::Serialization(
45 "GGUF format not yet implemented. Enable 'gguf' feature and use realizar integration."
46 .to_string(),
47 )),
48 }
49}
50
51fn save_json(model: &Model, path: &Path, pretty: bool) -> Result<()> {
53 let state = model.to_state();
54 let data = if pretty {
55 serde_json::to_string_pretty(&state)
56 .map_err(|e| Error::Serialization(format!("JSON serialization failed: {e}")))?
57 } else {
58 serde_json::to_string(&state)
59 .map_err(|e| Error::Serialization(format!("JSON serialization failed: {e}")))?
60 };
61 let mut file = File::create(path)?;
62 file.write_all(data.as_bytes())?;
63 Ok(())
64}
65
66fn save_yaml(model: &Model, path: &Path) -> Result<()> {
68 let state = model.to_state();
69 let data = serde_yaml::to_string(&state)
70 .map_err(|e| Error::Serialization(format!("YAML serialization failed: {e}")))?;
71 let mut file = File::create(path)?;
72 file.write_all(data.as_bytes())?;
73 Ok(())
74}
75
76pub(crate) fn infer_all_tensor_shapes(
80 parameters: &[(String, Tensor)],
81) -> HashMap<String, Vec<usize>> {
82 let mut shapes = HashMap::new();
83
84 let hidden_size = parameters
86 .iter()
87 .find(|(n, _)| n.ends_with("layernorm.weight") || n == "model.norm.weight")
88 .map_or(0, |(_, t)| t.len());
89
90 for (name, tensor) in parameters {
91 let numel = tensor.len();
92 let shape = if name.ends_with("layernorm.weight") || name == "model.norm.weight" {
93 vec![numel]
94 } else if hidden_size > 0 && numel % hidden_size == 0 {
95 let other_dim = numel / hidden_size;
96 if name.ends_with("down_proj.weight") {
102 vec![hidden_size, other_dim]
103 } else {
104 vec![other_dim, hidden_size]
105 }
106 } else {
107 vec![numel]
108 };
109 shapes.insert(name.clone(), shape);
110 }
111 shapes
112}
113
114fn save_safetensors(model: &Model, path: &Path) -> Result<()> {
116 let shapes = infer_all_tensor_shapes(&model.parameters);
118
119 let tensor_data: Vec<(String, Vec<u8>, Vec<usize>)> = model
121 .parameters
122 .iter()
123 .map(|(name, tensor)| {
124 let data = tensor.data();
125 let bytes: Vec<u8> =
126 bytemuck::cast_slice(data.as_slice().expect("tensor data must be contiguous"))
127 .to_vec();
128 let shape = shapes.get(name).cloned().unwrap_or_else(|| vec![tensor.len()]);
129 (name.clone(), bytes, shape)
130 })
131 .collect();
132
133 let views: Vec<(&str, TensorView<'_>)> = tensor_data
135 .iter()
136 .map(|(name, bytes, shape)| {
137 let view = TensorView::new(Dtype::F32, shape.clone(), bytes)
138 .expect("TensorView construction must not fail for valid F32 data");
139 (name.as_str(), view)
140 })
141 .collect();
142
143 let mut metadata = HashMap::new();
145 metadata.insert("name".to_string(), model.metadata.name.clone());
146 metadata.insert("architecture".to_string(), model.metadata.architecture.clone());
147 metadata.insert("version".to_string(), model.metadata.version.clone());
148
149 let safetensor_bytes = safetensors::serialize(views, Some(metadata))
151 .map_err(|e| Error::Serialization(format!("SafeTensors serialization failed: {e}")))?;
152
153 std::fs::write(path, safetensor_bytes)?;
155
156 Ok(())
157}
158
159fn save_apr(model: &Model, path: &Path) -> Result<()> {
164 use aprender::serialization::apr::AprWriter;
165 use serde_json::Value as JsonValue;
166
167 let mut writer = AprWriter::new();
168
169 writer.set_metadata("model_name", JsonValue::String(model.metadata.name.clone()));
171 writer.set_metadata("architecture", JsonValue::String(model.metadata.architecture.clone()));
172 writer.set_metadata("version", JsonValue::String(model.metadata.version.clone()));
173 writer.set_metadata("format", JsonValue::String("entrenar-checkpoint".into()));
174
175 let shapes = infer_all_tensor_shapes(&model.parameters);
177
178 for (name, tensor) in &model.parameters {
179 let data = tensor.data();
180 let slice = data.as_slice().expect("tensor data must be contiguous");
181 let shape = shapes.get(name).cloned().unwrap_or_else(|| vec![tensor.len()]);
182 writer.add_tensor_f32(name, shape, slice);
183 }
184
185 writer
186 .write(path)
187 .map_err(|e| Error::Serialization(format!("APR serialization failed: {e}")))?;
188
189 Ok(())
190}
191
// Unit tests for `save_model` and the per-format writers. Each test writes to a
// `tempfile::NamedTempFile` and inspects the bytes that end up on disk.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::io::{Model, ModelMetadata};
    use crate::Tensor;
    use tempfile::NamedTempFile;

    // JSON output contains the model name and architecture from the metadata.
    #[test]
    fn test_save_model_json() {
        let params = vec![
            ("weight".to_string(), Tensor::from_vec(vec![1.0, 2.0, 3.0], true)),
            ("bias".to_string(), Tensor::from_vec(vec![0.1], false)),
        ];

        let model = Model::new(ModelMetadata::new("test-model", "linear"), params);
        let config = SaveConfig::new(ModelFormat::Json);

        let temp_file = NamedTempFile::new().expect("temp file creation should succeed");
        save_model(&model, temp_file.path(), &config).expect("save should succeed");

        let content = std::fs::read_to_string(temp_file.path()).expect("file read should succeed");
        assert!(!content.is_empty());
        assert!(content.contains("test-model"));
        assert!(content.contains("linear"));
    }

    // YAML output contains the model name and architecture from the metadata.
    #[test]
    fn test_save_model_yaml() {
        let params = vec![("weight".to_string(), Tensor::from_vec(vec![1.0, 2.0], true))];

        let model = Model::new(ModelMetadata::new("test", "simple"), params);
        let config = SaveConfig::new(ModelFormat::Yaml);

        let temp_file = NamedTempFile::new().expect("temp file creation should succeed");
        save_model(&model, temp_file.path(), &config).expect("save should succeed");

        let content = std::fs::read_to_string(temp_file.path()).expect("file read should succeed");
        assert!(content.contains("test"));
        assert!(content.contains("simple"));
    }

    // `pretty = true` produces multi-line JSON.
    #[test]
    fn test_save_model_json_pretty() {
        let params = vec![("w".to_string(), Tensor::from_vec(vec![1.0], false))];
        let model = Model::new(ModelMetadata::new("pretty-test", "test"), params);
        let config = SaveConfig::new(ModelFormat::Json).with_pretty(true);

        let temp_file = NamedTempFile::new().expect("temp file creation should succeed");
        save_model(&model, temp_file.path(), &config).expect("save should succeed");

        let content = std::fs::read_to_string(temp_file.path()).expect("file read should succeed");
        assert!(content.contains('\n'));
    }

    // `pretty = false` produces single-line JSON.
    #[test]
    fn test_save_model_json_compact() {
        let params = vec![("w".to_string(), Tensor::from_vec(vec![1.0], false))];
        let model = Model::new(ModelMetadata::new("compact-test", "test"), params);
        let config = SaveConfig::new(ModelFormat::Json).with_pretty(false);

        let temp_file = NamedTempFile::new().expect("temp file creation should succeed");
        save_model(&model, temp_file.path(), &config).expect("save should succeed");

        let content = std::fs::read_to_string(temp_file.path()).expect("file read should succeed");
        let lines: Vec<&str> = content.lines().collect();
        assert_eq!(lines.len(), 1);
    }

    // A model with no parameters still saves, and its metadata is written.
    #[test]
    fn test_save_model_empty_params() {
        let model = Model::new(ModelMetadata::new("empty", "test"), vec![]);
        let config = SaveConfig::new(ModelFormat::Json);

        let temp_file = NamedTempFile::new().expect("temp file creation should succeed");
        save_model(&model, temp_file.path(), &config).expect("save should succeed");

        let content = std::fs::read_to_string(temp_file.path()).expect("file read should succeed");
        assert!(content.contains("empty"));
    }

    // A 1000-element tensor yields JSON longer than 1000 characters, i.e. the
    // values are written out rather than elided.
    #[test]
    fn test_save_model_large_tensor() {
        let large_data: Vec<f32> = (0..1000).map(|i| i as f32 * 0.001).collect();
        let params = vec![("large".to_string(), Tensor::from_vec(large_data, false))];
        let model = Model::new(ModelMetadata::new("large", "test"), params);
        let config = SaveConfig::new(ModelFormat::Json);

        let temp_file = NamedTempFile::new().expect("temp file creation should succeed");
        save_model(&model, temp_file.path(), &config).expect("save should succeed");

        let content = std::fs::read_to_string(temp_file.path()).expect("file read should succeed");
        assert!(content.len() > 1000);
    }

    // `with_pretty` sets the flag and keeps the selected format.
    #[test]
    fn test_save_config_builder() {
        let config = SaveConfig::new(ModelFormat::Json).with_pretty(true);
        assert!(config.pretty);
        assert_eq!(config.format, ModelFormat::Json);
    }

    // Setting `compress` does not break JSON saving. The file is still readable
    // as text and contains the model name (`save_model` does not read `compress`).
    #[test]
    fn test_save_model_with_compress_option() {
        let params = vec![("w".to_string(), Tensor::from_vec(vec![1.0], false))];
        let model = Model::new(ModelMetadata::new("compress-test", "test"), params);
        let config = SaveConfig::new(ModelFormat::Json).with_compress(true);

        let temp_file = NamedTempFile::new().expect("temp file creation should succeed");
        save_model(&model, temp_file.path(), &config).expect("save should succeed");

        let content = std::fs::read_to_string(temp_file.path()).expect("file read should succeed");
        assert!(content.contains("compress-test"));
    }

    // YAML output includes every parameter name.
    #[test]
    fn test_save_model_multiple_tensors() {
        let params = vec![
            ("layer1.weight".to_string(), Tensor::from_vec(vec![1.0, 2.0], true)),
            ("layer1.bias".to_string(), Tensor::from_vec(vec![0.1], true)),
            ("layer2.weight".to_string(), Tensor::from_vec(vec![3.0, 4.0], false)),
        ];
        let model = Model::new(ModelMetadata::new("multi", "deep"), params);
        let config = SaveConfig::new(ModelFormat::Yaml);

        let temp_file = NamedTempFile::new().expect("temp file creation should succeed");
        save_model(&model, temp_file.path(), &config).expect("save should succeed");

        let content = std::fs::read_to_string(temp_file.path()).expect("file read should succeed");
        assert!(content.contains("layer1.weight"));
        assert!(content.contains("layer2.weight"));
    }

    // Custom metadata entries are accepted when saving to JSON.
    // NOTE(review): the assertion only checks for the substring "version". The
    // metadata also carries a built-in `version` field (read by the SafeTensors
    // writer), so this check may pass even if custom entries are dropped.
    #[test]
    fn test_save_model_with_metadata() {
        let params = vec![("w".to_string(), Tensor::from_vec(vec![1.0], false))];
        let meta = ModelMetadata::new("meta-test", "test")
            .with_custom("version", serde_json::json!("1.0.0"))
            .with_custom("author", serde_json::json!("test"));
        let model = Model::new(meta, params);
        let config = SaveConfig::new(ModelFormat::Json);

        let temp_file = NamedTempFile::new().expect("temp file creation should succeed");
        save_model(&model, temp_file.path(), &config).expect("save should succeed");

        let content = std::fs::read_to_string(temp_file.path()).expect("file read should succeed");
        assert!(content.contains("version"));
    }

    // The default config is JSON, pretty-printed and uncompressed.
    #[test]
    fn test_save_config_default() {
        let config = SaveConfig::default();
        assert_eq!(config.format, ModelFormat::Json);
        assert!(config.pretty);
        assert!(!config.compress);
    }

    // Saving into a nonexistent directory returns an error rather than panicking.
    #[test]
    fn test_save_model_invalid_path() {
        let params = vec![("w".to_string(), Tensor::from_vec(vec![1.0], false))];
        let model = Model::new(ModelMetadata::new("test", "test"), params);
        let config = SaveConfig::new(ModelFormat::Json);

        let result = save_model(&model, "/nonexistent/directory/model.json", &config);
        assert!(result.is_err());
    }

    // SafeTensors output is non-trivial. The `> 8` bound presumably reflects the
    // 8-byte header-length prefix of the format.
    #[test]
    fn test_save_model_safetensors() {
        let params = vec![
            ("weight".to_string(), Tensor::from_vec(vec![1.0, 2.0, 3.0], true)),
            ("bias".to_string(), Tensor::from_vec(vec![0.1], false)),
        ];

        let model = Model::new(ModelMetadata::new("safetensor-test", "linear"), params);
        let config = SaveConfig::new(ModelFormat::SafeTensors);

        let temp_file = NamedTempFile::new().expect("temp file creation should succeed");
        save_model(&model, temp_file.path(), &config).expect("save should succeed");

        let content = std::fs::read(temp_file.path()).expect("file read should succeed");
        assert!(!content.is_empty());
        assert!(content.len() > 8);
    }

    // Round trip: the `safetensors` crate can load the file, every tensor name
    // is present, and values are preserved. With no norm weight present, no hidden
    // size is inferred, so the weight stays 1-D.
    #[test]
    fn test_save_model_safetensors_can_be_loaded() {
        let params = vec![
            ("layer1.weight".to_string(), Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], true)),
            ("layer1.bias".to_string(), Tensor::from_vec(vec![0.5], false)),
        ];

        let model = Model::new(ModelMetadata::new("roundtrip-test", "mlp"), params);
        let config = SaveConfig::new(ModelFormat::SafeTensors);

        let temp_file = NamedTempFile::new().expect("temp file creation should succeed");
        save_model(&model, temp_file.path(), &config).expect("save should succeed");

        let data = std::fs::read(temp_file.path()).expect("file read should succeed");
        let loaded = safetensors::SafeTensors::deserialize(&data).expect("load should succeed");

        let names = loaded.names();
        assert!(names.contains(&"layer1.weight"));
        assert!(names.contains(&"layer1.bias"));

        let weight = loaded.tensor("layer1.weight").expect("load should succeed");
        assert_eq!(weight.shape(), &[4]);
        let weight_data: &[f32] = bytemuck::cast_slice(weight.data());
        assert_eq!(weight_data, &[1.0, 2.0, 3.0, 4.0]);
    }

    // The model name and architecture are stored in the SafeTensors header metadata.
    #[test]
    fn test_save_safetensors_metadata() {
        let params = vec![("w".to_string(), Tensor::from_vec(vec![1.0], false))];
        let model = Model::new(ModelMetadata::new("meta-model", "transformer"), params);
        let config = SaveConfig::new(ModelFormat::SafeTensors);

        let temp_file = NamedTempFile::new().expect("temp file creation should succeed");
        save_model(&model, temp_file.path(), &config).expect("save should succeed");

        let data = std::fs::read(temp_file.path()).expect("file read should succeed");
        let (_, st_metadata) =
            safetensors::SafeTensors::read_metadata(&data).expect("deserialization should succeed");

        let metadata = st_metadata.metadata();
        assert!(metadata.is_some());
        let meta = metadata.as_ref().expect("operation should succeed");
        assert_eq!(meta.get("name").expect("key should exist"), "meta-model");
        assert_eq!(meta.get("architecture").expect("key should exist"), "transformer");
    }

    // A 10k-element tensor round-trips with its length and endpoint values intact.
    #[test]
    fn test_save_safetensors_large_tensor() {
        let large_data: Vec<f32> = (0..10000).map(|i| i as f32 * 0.001).collect();
        let params =
            vec![("large_weights".to_string(), Tensor::from_vec(large_data.clone(), false))];
        let model = Model::new(ModelMetadata::new("large", "test"), params);
        let config = SaveConfig::new(ModelFormat::SafeTensors);

        let temp_file = NamedTempFile::new().expect("temp file creation should succeed");
        save_model(&model, temp_file.path(), &config).expect("save should succeed");

        let data = std::fs::read(temp_file.path()).expect("file read should succeed");
        let loaded = safetensors::SafeTensors::deserialize(&data).expect("load should succeed");
        let tensor = loaded.tensor("large_weights").expect("load should succeed");
        let tensor_data: &[f32] = bytemuck::cast_slice(tensor.data());
        assert_eq!(tensor_data.len(), 10000);
        assert!((tensor_data[0] - 0.0).abs() < 1e-6);
        assert!((tensor_data[9999] - 9.999).abs() < 1e-3);
    }

    // The SafeTensors writer also reports an unwritable path as an error.
    #[test]
    fn test_save_safetensors_invalid_path() {
        let params = vec![("w".to_string(), Tensor::from_vec(vec![1.0], false))];
        let model = Model::new(ModelMetadata::new("test", "test"), params);
        let config = SaveConfig::new(ModelFormat::SafeTensors);

        let result = save_model(&model, "/nonexistent/directory/model.safetensors", &config);
        assert!(result.is_err());
    }

    // An empty parameter list produces a valid SafeTensors file with zero tensors.
    #[test]
    fn test_save_safetensors_empty_params() {
        let model = Model::new(ModelMetadata::new("empty", "test"), vec![]);
        let config = SaveConfig::new(ModelFormat::SafeTensors);

        let temp_file = NamedTempFile::new().expect("temp file creation should succeed");
        save_model(&model, temp_file.path(), &config).expect("save should succeed");

        let data = std::fs::read(temp_file.path()).expect("file read should succeed");
        let loaded = safetensors::SafeTensors::deserialize(&data).expect("load should succeed");
        assert_eq!(loaded.len(), 0);
    }

    // All four dotted parameter names are preserved in the SafeTensors file.
    #[test]
    fn test_save_safetensors_multiple_tensors() {
        let params = vec![
            ("encoder.layer1.weight".to_string(), Tensor::from_vec(vec![1.0, 2.0], true)),
            ("encoder.layer1.bias".to_string(), Tensor::from_vec(vec![0.1], true)),
            ("encoder.layer2.weight".to_string(), Tensor::from_vec(vec![3.0, 4.0, 5.0], false)),
            ("decoder.layer1.weight".to_string(), Tensor::from_vec(vec![6.0, 7.0], false)),
        ];
        let model = Model::new(ModelMetadata::new("encoder-decoder", "transformer"), params);
        let config = SaveConfig::new(ModelFormat::SafeTensors);

        let temp_file = NamedTempFile::new().expect("temp file creation should succeed");
        save_model(&model, temp_file.path(), &config).expect("save should succeed");

        let data = std::fs::read(temp_file.path()).expect("file read should succeed");
        let loaded = safetensors::SafeTensors::deserialize(&data).expect("load should succeed");
        assert_eq!(loaded.len(), 4);

        let names = loaded.names();
        assert!(names.contains(&"encoder.layer1.weight"));
        assert!(names.contains(&"decoder.layer1.weight"));
    }

    // Llama-style parameter names are saved with 2-D shapes. The hidden size is
    // taken from the norm weights by `infer_all_tensor_shapes`:
    // - norm weights stay 1-D;
    // - projections are stored as `[other, hidden]`;
    // - `down_proj` is stored as `[hidden, other]`.
    #[test]
    fn test_safetensors_saves_2d_shapes() {
        let hidden = 64;
        let intermediate = 128;
        let vocab = 256;

        let params = vec![
            ("model.embed_tokens.weight".to_string(), Tensor::zeros(vocab * hidden, false)),
            ("model.norm.weight".to_string(), Tensor::zeros(hidden, false)),
            ("model.layers.0.input_layernorm.weight".to_string(), Tensor::zeros(hidden, false)),
            (
                "model.layers.0.post_attention_layernorm.weight".to_string(),
                Tensor::zeros(hidden, false),
            ),
            (
                "model.layers.0.self_attn.q_proj.weight".to_string(),
                Tensor::zeros(hidden * hidden, false),
            ),
            (
                "model.layers.0.self_attn.k_proj.weight".to_string(),
                Tensor::zeros(16 * hidden, false),
            ),
            (
                "model.layers.0.self_attn.v_proj.weight".to_string(),
                Tensor::zeros(16 * hidden, false),
            ),
            (
                "model.layers.0.self_attn.o_proj.weight".to_string(),
                Tensor::zeros(hidden * hidden, false),
            ),
            (
                "model.layers.0.mlp.gate_proj.weight".to_string(),
                Tensor::zeros(intermediate * hidden, false),
            ),
            (
                "model.layers.0.mlp.up_proj.weight".to_string(),
                Tensor::zeros(intermediate * hidden, false),
            ),
            (
                "model.layers.0.mlp.down_proj.weight".to_string(),
                Tensor::zeros(hidden * intermediate, false),
            ),
        ];

        let metadata = ModelMetadata::new("test", "LlamaForCausalLM");
        let model = Model::new(metadata, params);
        let config =
            crate::io::format::SaveConfig::new(crate::io::format::ModelFormat::SafeTensors);
        let temp = NamedTempFile::new().unwrap();
        save_model(&model, temp.path(), &config).unwrap();

        let data = std::fs::read(temp.path()).unwrap();
        let loaded = safetensors::SafeTensors::deserialize(&data).unwrap();

        // Norm weights remain 1-D.
        assert_eq!(loaded.tensor("model.norm.weight").unwrap().shape(), &[hidden]);
        assert_eq!(
            loaded.tensor("model.layers.0.input_layernorm.weight").unwrap().shape(),
            &[hidden]
        );

        // Embeddings and input-side projections are `[other, hidden]`.
        assert_eq!(loaded.tensor("model.embed_tokens.weight").unwrap().shape(), &[vocab, hidden]);
        assert_eq!(
            loaded.tensor("model.layers.0.self_attn.q_proj.weight").unwrap().shape(),
            &[hidden, hidden]
        );
        assert_eq!(
            loaded.tensor("model.layers.0.self_attn.k_proj.weight").unwrap().shape(),
            &[16, hidden]
        );
        assert_eq!(
            loaded.tensor("model.layers.0.mlp.gate_proj.weight").unwrap().shape(),
            &[intermediate, hidden]
        );
        // `down_proj` writes back into the hidden size, so hidden comes first.
        assert_eq!(
            loaded.tensor("model.layers.0.mlp.down_proj.weight").unwrap().shape(),
            &[hidden, intermediate]
        );
    }
}