1use std::collections::HashMap;
29use std::fs::File;
30use std::io::{BufReader, BufWriter, Read, Write};
31use std::path::Path;
32
/// Free-form string key/value metadata attached to a checkpoint; serialized
/// as a flat JSON object by `serialize_metadata`.
pub type CheckpointMetadata = HashMap<String, String>;
39
/// An in-memory model checkpoint: a format version, free-form string
/// metadata, and a list of named tensors.
#[derive(Debug)]
pub struct Checkpoint {
    /// Format version. `new()` sets this to 1; `read_from` rejects any
    /// value other than 1.
    pub version: u32,
    /// Arbitrary key/value metadata, serialized as a flat JSON object.
    pub metadata: CheckpointMetadata,
    /// Tensors in insertion order; duplicate names are not rejected.
    pub tensors: Vec<CheckpointTensor>,
}
50
/// A single named tensor: its shape (dimension sizes) and raw `f32` data.
#[derive(Debug, Clone)]
pub struct CheckpointTensor {
    /// Tensor name; serialization limits names to 65535 bytes.
    pub name: String,
    /// Dimension sizes. NOTE(review): `data.len()` is never validated
    /// against the shape product — callers are trusted to keep them
    /// consistent.
    pub shape: Vec<u64>,
    /// Flat element data. The element ordering convention (e.g. row-major)
    /// is not enforced here — confirm with producers/consumers.
    pub data: Vec<f32>,
}
61
62impl CheckpointTensor {
67 pub fn new(name: impl Into<String>, data: Vec<f32>, shape: Vec<u64>) -> Self {
69 Self {
70 name: name.into(),
71 shape,
72 data,
73 }
74 }
75
76 pub fn element_count(&self) -> u64 {
78 if self.shape.is_empty() {
79 return 0;
80 }
81 self.shape.iter().product()
82 }
83
84 pub fn size_bytes(&self) -> usize {
86 self.element_count() as usize * 4
87 }
88
89 pub fn from_weight_tensor(wt: &crate::model_merge::WeightTensor) -> Self {
93 Self {
94 name: wt.name.clone(),
95 shape: wt.shape.iter().map(|&d| d as u64).collect(),
96 data: wt.data.clone(),
97 }
98 }
99
100 pub fn to_weight_tensor(&self) -> crate::model_merge::WeightTensor {
106 let shape: Vec<usize> = self
107 .shape
108 .iter()
109 .map(|&d| usize::try_from(d).unwrap_or(usize::MAX))
110 .collect();
111 crate::model_merge::WeightTensor::new(self.name.clone(), self.data.clone(), shape)
112 }
113}
114
115impl Checkpoint {
120 pub fn new() -> Self {
122 Self {
123 version: 1,
124 metadata: CheckpointMetadata::new(),
125 tensors: Vec::new(),
126 }
127 }
128
129 pub fn add_tensor(&mut self, tensor: CheckpointTensor) {
131 self.tensors.push(tensor);
132 }
133
134 pub fn set_metadata(&mut self, key: impl Into<String>, value: impl Into<String>) {
136 self.metadata.insert(key.into(), value.into());
137 }
138
139 pub fn get_metadata(&self, key: &str) -> Option<&str> {
141 self.metadata.get(key).map(|s| s.as_str())
142 }
143
144 pub fn get_tensor(&self, name: &str) -> Option<&CheckpointTensor> {
146 self.tensors.iter().find(|t| t.name == name)
147 }
148
149 pub fn total_bytes(&self) -> usize {
151 self.tensors.iter().map(|t| t.size_bytes()).sum()
152 }
153
154 pub fn num_params(&self) -> u64 {
156 self.tensors.iter().map(|t| t.element_count()).sum()
157 }
158
159 pub fn save(&self, path: &Path) -> Result<(), CheckpointError> {
163 let file = File::create(path)?;
164 let mut writer = BufWriter::new(file);
165 self.write_to(&mut writer)
166 }
167
168 pub fn load(path: &Path) -> Result<Self, CheckpointError> {
170 let file = File::open(path)?;
171 let mut reader = BufReader::new(file);
172 Self::read_from(&mut reader)
173 }
174
175 pub fn write_to<W: Write>(&self, writer: &mut W) -> Result<(), CheckpointError> {
183 writer.write_all(b"OXCK")?;
185 write_u32_le(writer, 1u32)?; write_u64_le(writer, 0u64)?; write_u64_le(writer, self.tensors.len() as u64)?;
188
189 let meta_str = serialize_metadata(&self.metadata);
191 let meta_bytes = meta_str.as_bytes();
192 write_u32_le(writer, meta_bytes.len() as u32)?;
193 writer.write_all(meta_bytes)?;
194
195 for tensor in &self.tensors {
197 let name_bytes = tensor.name.as_bytes();
198 if name_bytes.len() > 65535 {
199 return Err(CheckpointError::NameTooLong(name_bytes.len()));
200 }
201 write_u32_le(writer, name_bytes.len() as u32)?;
202 writer.write_all(name_bytes)?;
203
204 write_u32_le(writer, tensor.shape.len() as u32)?;
205 for &dim in &tensor.shape {
206 write_u64_le(writer, dim)?;
207 }
208
209 write_u64_le(writer, tensor.data.len() as u64)?;
210 for &f in &tensor.data {
211 writer.write_all(&f.to_le_bytes())?;
212 }
213 }
214
215 Ok(())
216 }
217
218 pub fn read_from<R: Read>(reader: &mut R) -> Result<Self, CheckpointError> {
220 let mut magic = [0u8; 4];
222 read_exact(reader, &mut magic)?;
223 if &magic != b"OXCK" {
224 return Err(CheckpointError::InvalidMagic(magic.to_vec()));
225 }
226
227 let version = read_u32_le(reader)?;
229 if version == 0 || version > 1 {
230 return Err(CheckpointError::UnsupportedVersion(version));
231 }
232
233 let _flags = read_u64_le(reader)?;
235
236 let num_tensors = read_u64_le(reader)? as usize;
238
239 let meta_len = read_u32_le(reader)? as usize;
241 let mut meta_bytes = vec![0u8; meta_len];
242 read_exact(reader, &mut meta_bytes)?;
243 let meta_str = std::str::from_utf8(&meta_bytes)
244 .map_err(|e| CheckpointError::MetadataParse(e.to_string()))?;
245 let metadata = deserialize_metadata(meta_str)?;
246
247 let mut tensors = Vec::with_capacity(num_tensors);
249 for _ in 0..num_tensors {
250 let name_len = read_u32_le(reader)? as usize;
252 let mut name_bytes = vec![0u8; name_len];
253 read_exact(reader, &mut name_bytes)?;
254 let name = String::from_utf8(name_bytes)
255 .map_err(|e| CheckpointError::MetadataParse(e.to_string()))?;
256
257 let ndim = read_u32_le(reader)? as usize;
259 let mut shape = Vec::with_capacity(ndim);
260 for _ in 0..ndim {
261 shape.push(read_u64_le(reader)?);
262 }
263
264 let data_len = read_u64_le(reader)? as usize;
266 let mut data = Vec::with_capacity(data_len);
267 for _ in 0..data_len {
268 let mut buf = [0u8; 4];
269 read_exact(reader, &mut buf)?;
270 data.push(f32::from_le_bytes(buf));
271 }
272
273 tensors.push(CheckpointTensor { name, shape, data });
274 }
275
276 Ok(Self {
277 version,
278 metadata,
279 tensors,
280 })
281 }
282}
283
284impl Default for Checkpoint {
285 fn default() -> Self {
286 Self::new()
287 }
288}
289
290fn serialize_metadata(meta: &CheckpointMetadata) -> String {
300 let mut pairs: Vec<(&String, &String)> = meta.iter().collect();
302 pairs.sort_by_key(|(k, _)| k.as_str());
303
304 let mut out = String::from('{');
305 for (i, (k, v)) in pairs.iter().enumerate() {
306 if i > 0 {
307 out.push(',');
308 }
309 out.push('"');
310 push_escaped(&mut out, k);
311 out.push_str("\":\"");
312 push_escaped(&mut out, v);
313 out.push('"');
314 }
315 out.push('}');
316 out
317}
318
/// Appends `s` to `out`, escaping characters that would corrupt the
/// surrounding JSON string literal.
///
/// Fix: in addition to `"` and `\`, the control characters `\n`, `\r`, and
/// `\t` are now escaped — previously they were emitted raw, producing
/// invalid JSON. This is backward-compatible: `parse_json_string` in this
/// file already decodes all five escapes, and old files with raw control
/// characters still round-trip through its pass-through branch.
fn push_escaped(out: &mut String, s: &str) {
    for ch in s.chars() {
        match ch {
            '"' => out.push_str("\\\""),
            '\\' => out.push_str("\\\\"),
            '\n' => out.push_str("\\n"),
            '\r' => out.push_str("\\r"),
            '\t' => out.push_str("\\t"),
            other => out.push(other),
        }
    }
}
329
/// Parses the flat JSON object produced by `serialize_metadata` back into a
/// key/value map.
///
/// Only string keys and string values are supported. Parsing is lenient
/// about whitespace and commas between pairs; any structural problem yields
/// `CheckpointError::MetadataParse`.
fn deserialize_metadata(s: &str) -> Result<CheckpointMetadata, CheckpointError> {
    let s = s.trim();
    if s.is_empty() {
        return Ok(CheckpointMetadata::new());
    }

    // Fast path for the common empty-object case.
    if s == "{}" {
        return Ok(CheckpointMetadata::new());
    }

    let bytes = s.as_bytes();
    if bytes.first() != Some(&b'{') || bytes.last() != Some(&b'}') {
        return Err(CheckpointError::MetadataParse(format!(
            "expected JSON object, got: {s}"
        )));
    }

    // Byte-index slicing is safe: the first and last bytes were just
    // verified to be ASCII '{' and '}'.
    let inner = &s[1..s.len() - 1];
    let mut map = CheckpointMetadata::new();

    if inner.trim().is_empty() {
        return Ok(map);
    }

    // Work over a char vector with a manual cursor; `parse_json_string`
    // returns the cursor position just past each closing quote.
    let chars: Vec<char> = inner.chars().collect();
    let mut pos = 0usize;

    loop {
        // Skip separators (commas) and whitespace before the next pair.
        while pos < chars.len() && (chars[pos] == ',' || chars[pos].is_whitespace()) {
            pos += 1;
        }
        if pos >= chars.len() {
            break;
        }

        // Each pair must start with an opening quote for the key.
        if chars[pos] != '"' {
            return Err(CheckpointError::MetadataParse(format!(
                "expected '\"' at position {pos}, got '{}'",
                chars[pos]
            )));
        }
        pos += 1;

        let (key, new_pos) = parse_json_string(&chars, pos)?;
        pos = new_pos;

        skip_ws(&chars, &mut pos);
        if pos >= chars.len() || chars[pos] != ':' {
            return Err(CheckpointError::MetadataParse(format!(
                "expected ':' after key '{key}'"
            )));
        }
        pos += 1;
        skip_ws(&chars, &mut pos);

        // Values must also be quoted strings; nested objects/arrays/numbers
        // are not supported by this format.
        if pos >= chars.len() || chars[pos] != '"' {
            return Err(CheckpointError::MetadataParse(format!(
                "expected '\"' for value of key '{key}'"
            )));
        }
        pos += 1;

        let (value, new_pos) = parse_json_string(&chars, pos)?;
        pos = new_pos;

        // On duplicate keys the last occurrence wins (HashMap::insert).
        map.insert(key, value);
    }

    Ok(map)
}
413
414fn parse_json_string(chars: &[char], mut pos: usize) -> Result<(String, usize), CheckpointError> {
418 let mut s = String::new();
419 while pos < chars.len() {
420 match chars[pos] {
421 '"' => {
422 pos += 1; return Ok((s, pos));
424 }
425 '\\' => {
426 pos += 1;
427 if pos >= chars.len() {
428 return Err(CheckpointError::MetadataParse(
429 "unexpected end after backslash".into(),
430 ));
431 }
432 match chars[pos] {
433 '"' => s.push('"'),
434 '\\' => s.push('\\'),
435 'n' => s.push('\n'),
436 'r' => s.push('\r'),
437 't' => s.push('\t'),
438 other => {
439 return Err(CheckpointError::MetadataParse(format!(
440 "unknown escape '\\{other}'"
441 )))
442 }
443 }
444 pos += 1;
445 }
446 ch => {
447 s.push(ch);
448 pos += 1;
449 }
450 }
451 }
452 Err(CheckpointError::MetadataParse("unterminated string".into()))
453}
454
/// Advances `pos` past any run of whitespace characters; stops at the first
/// non-whitespace character or the end of `chars`.
fn skip_ws(chars: &[char], pos: &mut usize) {
    // `get` makes the bounds check and the whitespace test a single guard.
    while chars.get(*pos).map_or(false, |c| c.is_whitespace()) {
        *pos += 1;
    }
}
461
462fn write_u32_le<W: Write>(w: &mut W, v: u32) -> Result<(), CheckpointError> {
467 w.write_all(&v.to_le_bytes())?;
468 Ok(())
469}
470
471fn write_u64_le<W: Write>(w: &mut W, v: u64) -> Result<(), CheckpointError> {
472 w.write_all(&v.to_le_bytes())?;
473 Ok(())
474}
475
476fn read_exact<R: Read>(r: &mut R, buf: &mut [u8]) -> Result<(), CheckpointError> {
477 let expected = buf.len();
478 let mut total_read = 0usize;
479 while total_read < expected {
480 match r.read(&mut buf[total_read..]) {
481 Ok(0) => {
482 return Err(CheckpointError::TruncatedData {
483 expected,
484 got: total_read,
485 })
486 }
487 Ok(n) => total_read += n,
488 Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
489 Err(e) => return Err(CheckpointError::Io(e)),
490 }
491 }
492 Ok(())
493}
494
495fn read_u32_le<R: Read>(r: &mut R) -> Result<u32, CheckpointError> {
496 let mut buf = [0u8; 4];
497 read_exact(r, &mut buf)?;
498 Ok(u32::from_le_bytes(buf))
499}
500
501fn read_u64_le<R: Read>(r: &mut R) -> Result<u64, CheckpointError> {
502 let mut buf = [0u8; 8];
503 read_exact(r, &mut buf)?;
504 Ok(u64::from_le_bytes(buf))
505}
506
/// Errors produced while reading or writing checkpoint files.
#[derive(Debug, thiserror::Error)]
pub enum CheckpointError {
    /// Underlying I/O failure (file open, read, or write).
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),

    /// The stream did not begin with the expected `OXCK` magic bytes;
    /// carries the 4 bytes actually read.
    #[error("invalid magic bytes: expected OXCK, got {0:?}")]
    InvalidMagic(Vec<u8>),

    /// The header declared a version this code cannot read (only 1 is
    /// accepted).
    #[error("unsupported checkpoint version: {0}")]
    UnsupportedVersion(u32),

    /// Metadata bytes were not valid UTF-8, or the JSON-like object was
    /// malformed. Also used for invalid UTF-8 in tensor names.
    #[error("metadata parse error: {0}")]
    MetadataParse(String),

    /// The stream ended before a fixed-size field or payload was complete.
    #[error("truncated data: expected {expected} bytes, got {got}")]
    TruncatedData { expected: usize, got: usize },

    /// A tensor name exceeded the 65535-byte serialization limit.
    #[error("tensor name too long: {0} bytes (max 65535)")]
    NameTooLong(usize),
}