1use super::format::{ModelFormat, SaveConfig};
4use super::model::Model;
5use crate::Tensor;
6use crate::{Error, Result};
7use safetensors::tensor::{Dtype, TensorView};
8use std::collections::HashMap;
9use std::fs::File;
10use std::io::Write;
11use std::path::Path;
12
13pub fn save_model(model: &Model, path: impl AsRef<Path>, config: &SaveConfig) -> Result<()> {
36 let path = path.as_ref();
37
38 match config.format {
39 ModelFormat::SafeTensors => save_safetensors(model, path),
40 ModelFormat::Apr => save_apr(model, path),
41 ModelFormat::Json => save_json(model, path, config.pretty),
42 ModelFormat::Yaml => save_yaml(model, path),
43 #[cfg(feature = "gguf")]
44 ModelFormat::Gguf => Err(Error::Serialization(
45 "GGUF format not yet implemented. Enable 'gguf' feature and use realizar integration."
46 .to_string(),
47 )),
48 }
49}
50
51fn save_json(model: &Model, path: &Path, pretty: bool) -> Result<()> {
53 let state = model.to_state();
54 let data = if pretty {
55 serde_json::to_string_pretty(&state)
56 .map_err(|e| Error::Serialization(format!("JSON serialization failed: {e}")))?
57 } else {
58 serde_json::to_string(&state)
59 .map_err(|e| Error::Serialization(format!("JSON serialization failed: {e}")))?
60 };
61 let mut file = File::create(path)?;
62 file.write_all(data.as_bytes())?;
63 Ok(())
64}
65
66fn save_yaml(model: &Model, path: &Path) -> Result<()> {
68 let state = model.to_state();
69 let data = serde_yaml::to_string(&state)
70 .map_err(|e| Error::Serialization(format!("YAML serialization failed: {e}")))?;
71 let mut file = File::create(path)?;
72 file.write_all(data.as_bytes())?;
73 Ok(())
74}
75
76fn infer_all_tensor_shapes(parameters: &[(String, Tensor)]) -> HashMap<String, Vec<usize>> {
80 let mut shapes = HashMap::new();
81
82 let hidden_size = parameters
84 .iter()
85 .find(|(n, _)| n.ends_with("layernorm.weight") || n == "model.norm.weight")
86 .map_or(0, |(_, t)| t.len());
87
88 for (name, tensor) in parameters {
89 let numel = tensor.len();
90 let shape = if name.ends_with("layernorm.weight") || name == "model.norm.weight" {
91 vec![numel]
92 } else if hidden_size > 0 && numel % hidden_size == 0 {
93 let other_dim = numel / hidden_size;
94 if name.ends_with("down_proj.weight") {
100 vec![hidden_size, other_dim]
101 } else {
102 vec![other_dim, hidden_size]
103 }
104 } else {
105 vec![numel]
106 };
107 shapes.insert(name.clone(), shape);
108 }
109 shapes
110}
111
112fn save_safetensors(model: &Model, path: &Path) -> Result<()> {
114 let shapes = infer_all_tensor_shapes(&model.parameters);
116
117 let tensor_data: Vec<(String, Vec<u8>, Vec<usize>)> = model
119 .parameters
120 .iter()
121 .map(|(name, tensor)| {
122 let data = tensor.data();
123 let bytes: Vec<u8> =
124 bytemuck::cast_slice(data.as_slice().expect("tensor data must be contiguous"))
125 .to_vec();
126 let shape = shapes.get(name).cloned().unwrap_or_else(|| vec![tensor.len()]);
127 (name.clone(), bytes, shape)
128 })
129 .collect();
130
131 let views: Vec<(&str, TensorView<'_>)> = tensor_data
133 .iter()
134 .map(|(name, bytes, shape)| {
135 let view = TensorView::new(Dtype::F32, shape.clone(), bytes)
136 .expect("TensorView construction must not fail for valid F32 data");
137 (name.as_str(), view)
138 })
139 .collect();
140
141 let mut metadata = HashMap::new();
143 metadata.insert("name".to_string(), model.metadata.name.clone());
144 metadata.insert("architecture".to_string(), model.metadata.architecture.clone());
145 metadata.insert("version".to_string(), model.metadata.version.clone());
146
147 let safetensor_bytes = safetensors::serialize(views, Some(metadata))
149 .map_err(|e| Error::Serialization(format!("SafeTensors serialization failed: {e}")))?;
150
151 std::fs::write(path, safetensor_bytes)?;
153
154 Ok(())
155}
156
157fn save_apr(model: &Model, path: &Path) -> Result<()> {
162 use aprender::serialization::apr::AprWriter;
163 use serde_json::Value as JsonValue;
164
165 let mut writer = AprWriter::new();
166
167 writer.set_metadata("model_name", JsonValue::String(model.metadata.name.clone()));
169 writer.set_metadata("architecture", JsonValue::String(model.metadata.architecture.clone()));
170 writer.set_metadata("version", JsonValue::String(model.metadata.version.clone()));
171 writer.set_metadata("format", JsonValue::String("entrenar-checkpoint".into()));
172
173 let shapes = infer_all_tensor_shapes(&model.parameters);
175
176 for (name, tensor) in &model.parameters {
177 let data = tensor.data();
178 let slice = data.as_slice().expect("tensor data must be contiguous");
179 let shape = shapes.get(name).cloned().unwrap_or_else(|| vec![tensor.len()]);
180 writer.add_tensor_f32(name, shape, slice);
181 }
182
183 writer
184 .write(path)
185 .map_err(|e| Error::Serialization(format!("APR serialization failed: {e}")))?;
186
187 Ok(())
188}
189
#[cfg(test)]
mod tests {
    use super::*;
    use crate::io::{Model, ModelMetadata};
    use crate::Tensor;
    use tempfile::NamedTempFile;

    /// Builds a named parameter entry from explicit values.
    fn param(name: &str, values: Vec<f32>, requires_grad: bool) -> (String, Tensor) {
        (name.to_string(), Tensor::from_vec(values, requires_grad))
    }

    /// Builds a named, zero-filled parameter entry that does not require gradients.
    fn zeros(name: &str, len: usize) -> (String, Tensor) {
        (name.to_string(), Tensor::zeros(len, false))
    }

    /// Saves `model` into a fresh temporary file and returns the file handle.
    fn save_to_temp(model: &Model, config: &SaveConfig) -> NamedTempFile {
        let temp_file = NamedTempFile::new().expect("temp file creation should succeed");
        save_model(model, temp_file.path(), config).expect("save should succeed");
        temp_file
    }

    /// Saves `model` and returns the written file as text.
    fn saved_text(model: &Model, config: &SaveConfig) -> String {
        let temp_file = save_to_temp(model, config);
        std::fs::read_to_string(temp_file.path()).expect("file read should succeed")
    }

    /// Saves `model` and returns the written file as raw bytes.
    fn saved_bytes(model: &Model, config: &SaveConfig) -> Vec<u8> {
        let temp_file = save_to_temp(model, config);
        std::fs::read(temp_file.path()).expect("file read should succeed")
    }

    #[test]
    fn test_save_model_json() {
        let model = Model::new(
            ModelMetadata::new("test-model", "linear"),
            vec![param("weight", vec![1.0, 2.0, 3.0], true), param("bias", vec![0.1], false)],
        );

        let content = saved_text(&model, &SaveConfig::new(ModelFormat::Json));

        assert!(!content.is_empty());
        assert!(content.contains("test-model"));
        assert!(content.contains("linear"));
    }

    #[test]
    fn test_save_model_yaml() {
        let model = Model::new(
            ModelMetadata::new("test", "simple"),
            vec![param("weight", vec![1.0, 2.0], true)],
        );

        let content = saved_text(&model, &SaveConfig::new(ModelFormat::Yaml));

        assert!(content.contains("test"));
        assert!(content.contains("simple"));
    }

    #[test]
    fn test_save_model_json_pretty() {
        let model =
            Model::new(ModelMetadata::new("pretty-test", "test"), vec![param("w", vec![1.0], false)]);
        let config = SaveConfig::new(ModelFormat::Json).with_pretty(true);

        let content = saved_text(&model, &config);

        // Pretty output spans several lines.
        assert!(content.contains('\n'));
    }

    #[test]
    fn test_save_model_json_compact() {
        let model = Model::new(
            ModelMetadata::new("compact-test", "test"),
            vec![param("w", vec![1.0], false)],
        );
        let config = SaveConfig::new(ModelFormat::Json).with_pretty(false);

        let content = saved_text(&model, &config);

        // Compact output is a single line.
        assert_eq!(content.lines().count(), 1);
    }

    #[test]
    fn test_save_model_empty_params() {
        let model = Model::new(ModelMetadata::new("empty", "test"), vec![]);

        let content = saved_text(&model, &SaveConfig::new(ModelFormat::Json));

        assert!(content.contains("empty"));
    }

    #[test]
    fn test_save_model_large_tensor() {
        let large_data: Vec<f32> = (0..1000).map(|i| i as f32 * 0.001).collect();
        let model =
            Model::new(ModelMetadata::new("large", "test"), vec![param("large", large_data, false)]);

        let content = saved_text(&model, &SaveConfig::new(ModelFormat::Json));

        assert!(content.len() > 1000);
    }

    #[test]
    fn test_save_config_builder() {
        let config = SaveConfig::new(ModelFormat::Json).with_pretty(true);

        assert!(config.pretty);
        assert_eq!(config.format, ModelFormat::Json);
    }

    #[test]
    fn test_save_model_with_compress_option() {
        let model = Model::new(
            ModelMetadata::new("compress-test", "test"),
            vec![param("w", vec![1.0], false)],
        );
        let config = SaveConfig::new(ModelFormat::Json).with_compress(true);

        let content = saved_text(&model, &config);

        assert!(content.contains("compress-test"));
    }

    #[test]
    fn test_save_model_multiple_tensors() {
        let model = Model::new(
            ModelMetadata::new("multi", "deep"),
            vec![
                param("layer1.weight", vec![1.0, 2.0], true),
                param("layer1.bias", vec![0.1], true),
                param("layer2.weight", vec![3.0, 4.0], false),
            ],
        );

        let content = saved_text(&model, &SaveConfig::new(ModelFormat::Yaml));

        assert!(content.contains("layer1.weight"));
        assert!(content.contains("layer2.weight"));
    }

    #[test]
    fn test_save_model_with_metadata() {
        let meta = ModelMetadata::new("meta-test", "test")
            .with_custom("version", serde_json::json!("1.0.0"))
            .with_custom("author", serde_json::json!("test"));
        let model = Model::new(meta, vec![param("w", vec![1.0], false)]);

        let content = saved_text(&model, &SaveConfig::new(ModelFormat::Json));

        assert!(content.contains("version"));
    }

    #[test]
    fn test_save_config_default() {
        let config = SaveConfig::default();

        assert_eq!(config.format, ModelFormat::Json);
        assert!(config.pretty);
        assert!(!config.compress);
    }

    #[test]
    fn test_save_model_invalid_path() {
        let model =
            Model::new(ModelMetadata::new("test", "test"), vec![param("w", vec![1.0], false)]);
        let config = SaveConfig::new(ModelFormat::Json);

        let result = save_model(&model, "/nonexistent/directory/model.json", &config);

        assert!(result.is_err());
    }

    #[test]
    fn test_save_model_safetensors() {
        let model = Model::new(
            ModelMetadata::new("safetensor-test", "linear"),
            vec![param("weight", vec![1.0, 2.0, 3.0], true), param("bias", vec![0.1], false)],
        );

        let content = saved_bytes(&model, &SaveConfig::new(ModelFormat::SafeTensors));

        assert!(!content.is_empty());
        assert!(content.len() > 8);
    }

    #[test]
    fn test_save_model_safetensors_can_be_loaded() {
        let model = Model::new(
            ModelMetadata::new("roundtrip-test", "mlp"),
            vec![
                param("layer1.weight", vec![1.0, 2.0, 3.0, 4.0], true),
                param("layer1.bias", vec![0.5], false),
            ],
        );

        let data = saved_bytes(&model, &SaveConfig::new(ModelFormat::SafeTensors));
        let loaded = safetensors::SafeTensors::deserialize(&data).expect("load should succeed");

        let names = loaded.names();
        assert!(names.contains(&"layer1.weight"));
        assert!(names.contains(&"layer1.bias"));

        let weight = loaded.tensor("layer1.weight").expect("load should succeed");
        assert_eq!(weight.shape(), &[4]);
        let weight_data: &[f32] = bytemuck::cast_slice(weight.data());
        assert_eq!(weight_data, &[1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_save_safetensors_metadata() {
        let model = Model::new(
            ModelMetadata::new("meta-model", "transformer"),
            vec![param("w", vec![1.0], false)],
        );

        let data = saved_bytes(&model, &SaveConfig::new(ModelFormat::SafeTensors));
        let (_, st_metadata) =
            safetensors::SafeTensors::read_metadata(&data).expect("deserialization should succeed");

        let metadata = st_metadata.metadata();
        assert!(metadata.is_some());
        let meta = metadata.as_ref().expect("operation should succeed");
        assert_eq!(meta.get("name").expect("key should exist"), "meta-model");
        assert_eq!(meta.get("architecture").expect("key should exist"), "transformer");
    }

    #[test]
    fn test_save_safetensors_large_tensor() {
        let large_data: Vec<f32> = (0..10000).map(|i| i as f32 * 0.001).collect();
        let model = Model::new(
            ModelMetadata::new("large", "test"),
            vec![param("large_weights", large_data, false)],
        );

        let data = saved_bytes(&model, &SaveConfig::new(ModelFormat::SafeTensors));
        let loaded = safetensors::SafeTensors::deserialize(&data).expect("load should succeed");
        let tensor = loaded.tensor("large_weights").expect("load should succeed");
        let tensor_data: &[f32] = bytemuck::cast_slice(tensor.data());

        assert_eq!(tensor_data.len(), 10000);
        assert!((tensor_data[0] - 0.0).abs() < 1e-6);
        assert!((tensor_data[9999] - 9.999).abs() < 1e-3);
    }

    #[test]
    fn test_save_safetensors_invalid_path() {
        let model =
            Model::new(ModelMetadata::new("test", "test"), vec![param("w", vec![1.0], false)]);
        let config = SaveConfig::new(ModelFormat::SafeTensors);

        let result = save_model(&model, "/nonexistent/directory/model.safetensors", &config);

        assert!(result.is_err());
    }

    #[test]
    fn test_save_safetensors_empty_params() {
        let model = Model::new(ModelMetadata::new("empty", "test"), vec![]);

        let data = saved_bytes(&model, &SaveConfig::new(ModelFormat::SafeTensors));
        let loaded = safetensors::SafeTensors::deserialize(&data).expect("load should succeed");

        assert_eq!(loaded.len(), 0);
    }

    #[test]
    fn test_save_safetensors_multiple_tensors() {
        let model = Model::new(
            ModelMetadata::new("encoder-decoder", "transformer"),
            vec![
                param("encoder.layer1.weight", vec![1.0, 2.0], true),
                param("encoder.layer1.bias", vec![0.1], true),
                param("encoder.layer2.weight", vec![3.0, 4.0, 5.0], false),
                param("decoder.layer1.weight", vec![6.0, 7.0], false),
            ],
        );

        let data = saved_bytes(&model, &SaveConfig::new(ModelFormat::SafeTensors));
        let loaded = safetensors::SafeTensors::deserialize(&data).expect("load should succeed");
        assert_eq!(loaded.len(), 4);

        let names = loaded.names();
        assert!(names.contains(&"encoder.layer1.weight"));
        assert!(names.contains(&"decoder.layer1.weight"));
    }

    #[test]
    fn test_safetensors_saves_2d_shapes() {
        let hidden = 64;
        let intermediate = 128;
        let vocab = 256;
        let kv_rows = 16;

        let params = vec![
            zeros("model.embed_tokens.weight", vocab * hidden),
            zeros("model.norm.weight", hidden),
            zeros("model.layers.0.input_layernorm.weight", hidden),
            zeros("model.layers.0.post_attention_layernorm.weight", hidden),
            zeros("model.layers.0.self_attn.q_proj.weight", hidden * hidden),
            zeros("model.layers.0.self_attn.k_proj.weight", kv_rows * hidden),
            zeros("model.layers.0.self_attn.v_proj.weight", kv_rows * hidden),
            zeros("model.layers.0.self_attn.o_proj.weight", hidden * hidden),
            zeros("model.layers.0.mlp.gate_proj.weight", intermediate * hidden),
            zeros("model.layers.0.mlp.up_proj.weight", intermediate * hidden),
            zeros("model.layers.0.mlp.down_proj.weight", hidden * intermediate),
        ];

        let model = Model::new(ModelMetadata::new("test", "LlamaForCausalLM"), params);
        let config = SaveConfig::new(ModelFormat::SafeTensors);
        let temp = NamedTempFile::new().unwrap();
        save_model(&model, temp.path(), &config).unwrap();

        let data = std::fs::read(temp.path()).unwrap();
        let loaded = safetensors::SafeTensors::deserialize(&data).unwrap();
        let shape_of = |name: &str| loaded.tensor(name).unwrap().shape().to_vec();

        // Norm weights stay one-dimensional.
        assert_eq!(shape_of("model.norm.weight"), vec![hidden]);
        assert_eq!(shape_of("model.layers.0.input_layernorm.weight"), vec![hidden]);

        // Weight matrices are written with two dimensions.
        assert_eq!(shape_of("model.embed_tokens.weight"), vec![vocab, hidden]);
        assert_eq!(shape_of("model.layers.0.self_attn.q_proj.weight"), vec![hidden, hidden]);
        assert_eq!(shape_of("model.layers.0.self_attn.k_proj.weight"), vec![kv_rows, hidden]);
        assert_eq!(shape_of("model.layers.0.mlp.gate_proj.weight"), vec![intermediate, hidden]);
        assert_eq!(shape_of("model.layers.0.mlp.down_proj.weight"), vec![hidden, intermediate]);
    }
}