1use serde_json::{json, Map, Value as JsonValue};
2use std::collections::BTreeMap;
3use std::fs::File;
4use std::io::Write;
5use std::path::Path;
6
7use crate::types::{CheckpointMetadata, ConvertError, Result, TensorData};
8
9pub(crate) fn write_safetensors(
10 path: &Path,
11 tensors: &BTreeMap<String, TensorData>,
12 source_sha256: &str,
13) -> Result<()> {
14 let mut data = Vec::new();
15 let mut tensor_meta = Map::new();
16
17 for (name, tensor) in tensors {
18 let start = data.len();
19 data.extend_from_slice(&tensor.bytes);
20 let end = data.len();
21
22 tensor_meta.insert(
23 name.clone(),
24 json!({
25 "dtype": tensor.dtype.as_safetensors(),
26 "shape": tensor.shape,
27 "data_offsets": [start, end],
28 }),
29 );
30 }
31
32 tensor_meta.insert(
33 "__metadata__".to_string(),
34 json!({
35 "format_version": "1",
36 "source_sha256": source_sha256,
37 "converter": "pt-loader",
38 }),
39 );
40
41 let header_bytes = serde_json::to_vec(&JsonValue::Object(tensor_meta))?;
42 let mut out = File::create(path)?;
43 out.write_all(&(header_bytes.len() as u64).to_le_bytes())?;
44 out.write_all(&header_bytes)?;
45 out.write_all(&data)?;
46 Ok(())
47}
48
49pub fn write_metadata_yaml(path: &Path, metadata: &CheckpointMetadata) -> Result<()> {
50 let raw = serde_yaml::to_string(metadata)
51 .map_err(|error| ConvertError::InvalidStructure(format!("metadata yaml encode failed: {}", error)))?;
52 let body = inline_known_int_vec_fields_in_tensors(&raw, &["shape", "stride"]);
53 std::fs::write(path, body)?;
54 Ok(())
55}
56
57pub fn inline_known_int_vec_fields_in_tensors(raw: &str, fields: &[&str]) -> String {
58 let lines: Vec<&str> = raw.lines().collect();
59 let mut out: Vec<String> = Vec::with_capacity(lines.len());
60 let mut idx = 0usize;
61 let mut in_tensors_section = false;
62
63 while idx < lines.len() {
64 let line = lines[idx];
65 let indent = line.chars().take_while(|ch| ch.is_whitespace()).count();
66 let trimmed = line.trim();
67
68 if indent == 0 && trimmed == "tensors:" {
69 in_tensors_section = true;
70 out.push(line.to_string());
71 idx += 1;
72 continue;
73 }
74
75 if in_tensors_section
76 && indent == 0
77 && !trimmed.is_empty()
78 && !trimmed.starts_with('-')
79 && trimmed.ends_with(':')
80 && trimmed != "tensors:"
81 {
82 in_tensors_section = false;
83 }
84
85 if in_tensors_section {
86 let mut replaced = false;
87 for field in fields {
88 if trimmed == format!("{field}:") {
89 let mut values: Vec<String> = Vec::new();
90 let mut probe = idx + 1;
91 while probe < lines.len() {
92 let next = lines[probe];
93 let next_indent = next.chars().take_while(|ch| ch.is_whitespace()).count();
94 if next_indent < indent {
95 break;
96 }
97 let next_trimmed = next.trim();
98 let Some(value) = next_trimmed.strip_prefix("- ") else {
99 break;
100 };
101 values.push(value.trim().to_string());
102 probe += 1;
103 }
104
105 if !values.is_empty() {
106 out.push(format!(
107 "{}{}: {}",
108 " ".repeat(indent),
109 field,
110 format_inline_int_vec(&values)
111 ));
112 idx = probe;
113 replaced = true;
114 break;
115 }
116 }
117 }
118 if replaced {
119 continue;
120 }
121 }
122
123 out.push(line.to_string());
124 idx += 1;
125 }
126
127 if raw.ends_with('\n') {
128 out.join("\n") + "\n"
129 } else {
130 out.join("\n")
131 }
132}
133
/// Renders `values` as a bracketed inline sequence, e.g. `[10, 20, 30]`.
///
/// Items are emitted verbatim and separated by `", "`; an empty slice yields
/// `[]`.
fn format_inline_int_vec(values: &[String]) -> String {
    format!("[{}]", values.join(", "))
}
145
// Unit tests for the inline-sequence formatting helpers in this file.
#[cfg(test)]
mod tests {
 use super::*;

 // Three items are rendered as one bracketed, comma-separated list.
 #[test]
 fn formats_inline_int_vec() {
 let rendered = format_inline_int_vec(&["10".to_string(), "20".to_string(), "30".to_string()]);
 assert_eq!(rendered, "[10, 20, 30]");
 }

 // A `shape` list under `metadata:` must stay in block form, while the
 // `shape` list under `tensors:` must be collapsed to the inline form.
 #[test]
 fn inlines_only_known_fields_in_tensors_section() {
 // The fixture's whitespace is significant: the function under test is
 // line- and indentation-based.
 let raw = r#"format_version: 1
metadata:
 shape:
 - 1
 - 2
tensors:
- name: x
 shape:
 - 10
 - 20
 - 30
 dtype: F32
"#;

 let rendered = inline_known_int_vec_fields_in_tensors(raw, &["shape", "stride"]);
 // Outside the tensors section the block-style list is left as-is.
 assert!(rendered.contains("metadata:\n shape:\n - 1\n - 2"));
 // Inside the tensors section the list is folded onto the key's line.
 assert!(rendered.contains(" shape: [10, 20, 30]"));
 }
}