Skip to main content

pt_loader/
writer.rs

1use serde_json::{json, Map, Value as JsonValue};
2use std::collections::BTreeMap;
3use std::fs::File;
4use std::io::Write;
5use std::path::Path;
6
7use crate::types::{CheckpointMetadata, ConvertError, Result, TensorData};
8
9pub(crate) fn write_safetensors(
10  path: &Path,
11  tensors: &BTreeMap<String, TensorData>,
12  source_sha256: &str,
13) -> Result<()> {
14  let mut data = Vec::new();
15  let mut tensor_meta = Map::new();
16
17  for (name, tensor) in tensors {
18    let start = data.len();
19    data.extend_from_slice(&tensor.bytes);
20    let end = data.len();
21
22    tensor_meta.insert(
23      name.clone(),
24      json!({
25        "dtype": tensor.dtype.as_safetensors(),
26        "shape": tensor.shape,
27        "data_offsets": [start, end],
28      }),
29    );
30  }
31
32  tensor_meta.insert(
33    "__metadata__".to_string(),
34    json!({
35      "format_version": "1",
36      "source_sha256": source_sha256,
37      "converter": "pt-loader",
38    }),
39  );
40
41  let header_bytes = serde_json::to_vec(&JsonValue::Object(tensor_meta))?;
42  let mut out = File::create(path)?;
43  out.write_all(&(header_bytes.len() as u64).to_le_bytes())?;
44  out.write_all(&header_bytes)?;
45  out.write_all(&data)?;
46  Ok(())
47}
48
49pub fn write_metadata_yaml(path: &Path, metadata: &CheckpointMetadata) -> Result<()> {
50  let raw = serde_yaml::to_string(metadata)
51    .map_err(|error| ConvertError::InvalidStructure(format!("metadata yaml encode failed: {}", error)))?;
52  let body = inline_known_int_vec_fields_in_tensors(&raw, &["shape", "stride"]);
53  std::fs::write(path, body)?;
54  Ok(())
55}
56
57pub fn inline_known_int_vec_fields_in_tensors(raw: &str, fields: &[&str]) -> String {
58  let lines: Vec<&str> = raw.lines().collect();
59  let mut out: Vec<String> = Vec::with_capacity(lines.len());
60  let mut idx = 0usize;
61  let mut in_tensors_section = false;
62
63  while idx < lines.len() {
64    let line = lines[idx];
65    let indent = line.chars().take_while(|ch| ch.is_whitespace()).count();
66    let trimmed = line.trim();
67
68    if indent == 0 && trimmed == "tensors:" {
69      in_tensors_section = true;
70      out.push(line.to_string());
71      idx += 1;
72      continue;
73    }
74
75    if in_tensors_section
76      && indent == 0
77      && !trimmed.is_empty()
78      && !trimmed.starts_with('-')
79      && trimmed.ends_with(':')
80      && trimmed != "tensors:"
81    {
82      in_tensors_section = false;
83    }
84
85    if in_tensors_section {
86      let mut replaced = false;
87      for field in fields {
88        if trimmed == format!("{field}:") {
89          let mut values: Vec<String> = Vec::new();
90          let mut probe = idx + 1;
91          while probe < lines.len() {
92            let next = lines[probe];
93            let next_indent = next.chars().take_while(|ch| ch.is_whitespace()).count();
94            if next_indent < indent {
95              break;
96            }
97            let next_trimmed = next.trim();
98            let Some(value) = next_trimmed.strip_prefix("- ") else {
99              break;
100            };
101            values.push(value.trim().to_string());
102            probe += 1;
103          }
104
105          if !values.is_empty() {
106            out.push(format!(
107              "{}{}: {}",
108              " ".repeat(indent),
109              field,
110              format_inline_int_vec(&values)
111            ));
112            idx = probe;
113            replaced = true;
114            break;
115          }
116        }
117      }
118      if replaced {
119        continue;
120      }
121    }
122
123    out.push(line.to_string());
124    idx += 1;
125  }
126
127  if raw.ends_with('\n') {
128    out.join("\n") + "\n"
129  } else {
130    out.join("\n")
131  }
132}
133
/// Renders `values` as a bracketed, comma-separated list, e.g. `[10, 20, 30]`.
///
/// Each value is emitted verbatim; an empty slice yields `[]`.
fn format_inline_int_vec(values: &[String]) -> String {
  format!("[{}]", values.join(", "))
}
145
#[cfg(test)]
mod tests {
  use super::*;

  /// The inline renderer joins items with `, ` inside square brackets.
  #[test]
  fn formats_inline_int_vec() {
    let items: Vec<String> = ["10", "20", "30"].iter().map(|item| item.to_string()).collect();
    assert_eq!(format_inline_int_vec(&items), "[10, 20, 30]");
  }

  /// Only `shape` under `tensors:` is collapsed; the same key under
  /// `metadata:` keeps its block-style list.
  #[test]
  fn inlines_only_known_fields_in_tensors_section() {
    let document = r#"format_version: 1
metadata:
  shape:
  - 1
  - 2
tensors:
- name: x
  shape:
  - 10
  - 20
  - 30
  dtype: F32
"#;

    let output = inline_known_int_vec_fields_in_tensors(document, &["shape", "stride"]);

    // Outside the tensors section: untouched.
    assert!(output.contains("metadata:\n  shape:\n  - 1\n  - 2"));
    // Inside the tensors section: folded onto one line.
    assert!(output.contains("  shape: [10, 20, 30]"));
  }
}