1use std::collections::HashMap;
6use std::fs::{self, File, OpenOptions};
7use std::io::{self, BufWriter, Write};
8use std::path::{Path, PathBuf};
9
10const ZSTD_MAGIC: [u8; 4] = [0x28, 0xB5, 0x2F, 0xFD];
12
13#[allow(unused_imports)]
14use super::functions::*;
15use super::functions::{
16 FORMAT_VERSION, MAGIC, TAG_FOOTER, TAG_INTEGERS, TAG_POSITIONS, TAG_SCALARS, TAG_VELOCITIES,
17};
18
19#[derive(Debug, Clone, PartialEq)]
24pub struct Checkpoint {
25 pub version: u32,
27 pub timestamp: u64,
29 pub step: u64,
31 pub sim_time: f64,
33 pub state: Vec<u8>,
35 pub checksum: u32,
37}
38impl Checkpoint {
39 pub fn new(version: u32, timestamp: u64, step: u64, sim_time: f64, state: Vec<u8>) -> Self {
41 let checksum = compute_checksum(&state);
42 Self {
43 version,
44 timestamp,
45 step,
46 sim_time,
47 state,
48 checksum,
49 }
50 }
51 pub fn compute_checksum(&mut self) {
53 self.checksum = compute_checksum(&self.state);
54 }
55 pub fn verify(&self) -> bool {
57 compute_checksum(&self.state) == self.checksum
58 }
59 pub fn to_bytes(&self) -> Vec<u8> {
70 let mut buf = Vec::new();
71 buf.extend_from_slice(&self.version.to_le_bytes());
72 buf.extend_from_slice(&self.timestamp.to_le_bytes());
73 buf.extend_from_slice(&self.step.to_le_bytes());
74 buf.extend_from_slice(&self.sim_time.to_bits().to_le_bytes());
75 buf.extend_from_slice(&self.checksum.to_le_bytes());
76 buf.extend_from_slice(&(self.state.len() as u64).to_le_bytes());
77 buf.extend_from_slice(&self.state);
78 buf
79 }
80 pub fn from_bytes(data: &[u8]) -> io::Result<Self> {
82 let mut cursor = 0usize;
83 let version = read_u32(data, &mut cursor)?;
84 let timestamp = read_u64(data, &mut cursor)?;
85 let step = read_u64(data, &mut cursor)?;
86 let sim_time = read_f64(data, &mut cursor)?;
87 let checksum = read_u32(data, &mut cursor)?;
88 let state_len = read_u64(data, &mut cursor)? as usize;
89 if cursor + state_len > data.len() {
90 return Err(io::Error::new(
91 io::ErrorKind::UnexpectedEof,
92 "state truncated",
93 ));
94 }
95 let state = data[cursor..cursor + state_len].to_vec();
96 Ok(Self {
97 version,
98 timestamp,
99 step,
100 sim_time,
101 state,
102 checksum,
103 })
104 }
105}
106#[allow(dead_code)]
108#[derive(Debug)]
109pub struct CheckpointInspector {
110 pub base_dir: PathBuf,
112}
113impl CheckpointInspector {
114 pub fn new(base_dir: impl Into<PathBuf>) -> Self {
116 Self {
117 base_dir: base_dir.into(),
118 }
119 }
120 pub fn list(&self) -> Vec<PathBuf> {
122 let mut paths: Vec<PathBuf> = match fs::read_dir(&self.base_dir) {
123 Ok(rd) => rd
124 .filter_map(|e| e.ok())
125 .map(|e| e.path())
126 .filter(|p| p.extension().is_some_and(|ext| ext == "bin"))
127 .collect(),
128 Err(_) => Vec::new(),
129 };
130 paths.sort();
131 paths
132 }
133 pub fn peek_header(&self, path: &Path) -> io::Result<Checkpoint> {
135 let data = fs::read(path)?;
136 Checkpoint::from_bytes(&data)
137 }
138 pub fn metadata_summary(&self) -> Vec<(u64, f64)> {
140 self.list()
141 .iter()
142 .filter_map(|p| self.peek_header(p).ok())
143 .map(|c| (c.step, c.sim_time))
144 .collect()
145 }
146 pub fn count(&self) -> usize {
148 self.list().len()
149 }
150}
151#[derive(Debug, Clone)]
154pub struct RestartFile {
155 pub meta: CheckpointMetadata,
157 pub positions: Vec<[f64; 3]>,
159 pub velocities: Vec<[f64; 3]>,
161 pub scalars: HashMap<String, Vec<f64>>,
163}
164impl RestartFile {
165 pub fn new(
167 meta: CheckpointMetadata,
168 positions: Vec<[f64; 3]>,
169 velocities: Vec<[f64; 3]>,
170 scalars: HashMap<String, Vec<f64>>,
171 ) -> Self {
172 Self {
173 meta,
174 positions,
175 velocities,
176 scalars,
177 }
178 }
179 pub fn save(&self, path: &Path) -> io::Result<()> {
181 let writer = CheckpointWriter::new(path);
182 writer.write_header(&self.meta)?;
183 writer.write_positions(&self.positions)?;
184 writer.write_velocities(&self.velocities)?;
185 let mut keys: Vec<&String> = self.scalars.keys().collect();
186 keys.sort();
187 for k in keys {
188 writer.write_scalars(k, &self.scalars[k])?;
189 }
190 writer.finalize()
191 }
192 pub fn load(path: &Path) -> io::Result<Self> {
194 let reader = CheckpointReader::new(path);
195 let meta = reader.read_metadata()?;
196 let positions = reader.read_positions()?;
197 let velocities = reader.read_velocities()?;
198 let data = fs::read(path)?;
199 let mut scalars: HashMap<String, Vec<f64>> = HashMap::new();
200 let mut cursor = {
201 let mut c = 0usize;
202 let _magic = read_u32(&data, &mut c)?;
203 let _version = read_u32(&data, &mut c)?;
204 let meta_len = read_u32(&data, &mut c)? as usize;
205 c += meta_len;
206 c
207 };
208 while cursor < data.len() {
209 let tag = data[cursor];
210 cursor += 1;
211 match tag {
212 TAG_POSITIONS | TAG_VELOCITIES => {
213 let count = read_u64(&data, &mut cursor)? as usize;
214 cursor += count * 24;
215 }
216 TAG_SCALARS => {
217 let name = read_name(&data, &mut cursor)?;
218 let count = read_u64(&data, &mut cursor)? as usize;
219 let mut vals = Vec::with_capacity(count);
220 for _ in 0..count {
221 vals.push(read_f64(&data, &mut cursor)?);
222 }
223 scalars.insert(name, vals);
224 }
225 TAG_INTEGERS => {
226 let _name = read_name(&data, &mut cursor)?;
227 let count = read_u64(&data, &mut cursor)? as usize;
228 cursor += count * 4;
229 }
230 TAG_FOOTER => break,
231 _ => {
232 return Err(io::Error::new(
233 io::ErrorKind::InvalidData,
234 format!("unknown tag 0x{tag:02X} while loading restart"),
235 ));
236 }
237 }
238 }
239 Ok(Self {
240 meta,
241 positions,
242 velocities,
243 scalars,
244 })
245 }
246}
/// On-disk representation selected for a checkpoint file.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CheckpointFormat {
    /// Files with the `bin` extension.
    Binary,
    /// Files with the `json` extension; the only text format.
    Json,
    /// Files with the `rle` extension.
    Compressed,
    /// Files with the `h5xt` extension.
    HDF5Like,
}

impl CheckpointFormat {
    /// File extension (without the leading dot) used for this format.
    pub fn extension(self) -> &'static str {
        match self {
            CheckpointFormat::HDF5Like => "h5xt",
            CheckpointFormat::Compressed => "rle",
            CheckpointFormat::Json => "json",
            CheckpointFormat::Binary => "bin",
        }
    }

    /// Returns `true` only for [`CheckpointFormat::Json`].
    pub fn is_text(self) -> bool {
        self == Self::Json
    }
}
275#[allow(dead_code)]
277#[derive(Debug, Clone, PartialEq)]
278pub struct CheckpointHeader {
279 pub version: [u32; 3],
281 pub timestamp: u64,
283 pub step: u64,
285 pub crate_name: String,
287 pub checksum: u32,
289}
290impl CheckpointHeader {
291 pub fn new(
293 version: [u32; 3],
294 timestamp: u64,
295 step: u64,
296 crate_name: impl Into<String>,
297 checksum: u32,
298 ) -> Self {
299 Self {
300 version,
301 timestamp,
302 step,
303 crate_name: crate_name.into(),
304 checksum,
305 }
306 }
307 pub fn to_bytes(&self) -> Vec<u8> {
309 let mut buf = Vec::new();
310 for v in &self.version {
311 buf.extend_from_slice(&v.to_le_bytes());
312 }
313 buf.extend_from_slice(&self.timestamp.to_le_bytes());
314 buf.extend_from_slice(&self.step.to_le_bytes());
315 let name_bytes = self.crate_name.as_bytes();
316 buf.extend_from_slice(&(name_bytes.len() as u32).to_le_bytes());
317 buf.extend_from_slice(name_bytes);
318 buf.extend_from_slice(&self.checksum.to_le_bytes());
319 buf
320 }
321 pub fn from_bytes(data: &[u8]) -> io::Result<Self> {
323 let mut c = 0usize;
324 let v0 = read_u32(data, &mut c)?;
325 let v1 = read_u32(data, &mut c)?;
326 let v2 = read_u32(data, &mut c)?;
327 let timestamp = read_u64(data, &mut c)?;
328 let step = read_u64(data, &mut c)?;
329 let name_len = read_u32(data, &mut c)? as usize;
330 if c + name_len > data.len() {
331 return Err(io::Error::new(
332 io::ErrorKind::UnexpectedEof,
333 "header crate_name truncated",
334 ));
335 }
336 let crate_name = std::str::from_utf8(&data[c..c + name_len])
337 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?
338 .to_owned();
339 c += name_len;
340 let checksum = read_u32(data, &mut c)?;
341 Ok(Self {
342 version: [v0, v1, v2],
343 timestamp,
344 step,
345 crate_name,
346 checksum,
347 })
348 }
349 pub fn version_compatible(&self, expected_major: u32) -> bool {
351 self.version[0] == expected_major
352 }
353}
354#[derive(Debug, Clone)]
357pub struct CheckpointFileWriter {
358 pub output_dir: PathBuf,
360}
361impl CheckpointFileWriter {
362 pub fn new(output_dir: impl Into<PathBuf>) -> Self {
364 Self {
365 output_dir: output_dir.into(),
366 }
367 }
368 pub fn write(&self, checkpoint: &Checkpoint) -> io::Result<PathBuf> {
371 let base_name = format!("checkpoint_{:010}", checkpoint.step);
372 let bin_path = self.output_dir.join(format!("{base_name}.bin"));
373 let json_path = self.output_dir.join(format!("{base_name}.json"));
374 fs::write(&bin_path, checkpoint.to_bytes())?;
375 let json = format!(
376 r#"{{"version":{},"timestamp":{},"step":{},"sim_time":{},"state_len":{},"checksum":{}}}"#,
377 checkpoint.version,
378 checkpoint.timestamp,
379 checkpoint.step,
380 checkpoint.sim_time,
381 checkpoint.state.len(),
382 checkpoint.checksum
383 );
384 fs::write(&json_path, json.as_bytes())?;
385 Ok(bin_path)
386 }
387}
388#[allow(dead_code)]
390#[derive(Debug, Clone)]
391pub struct SimulationState {
392 pub positions: Vec<[f64; 3]>,
394 pub velocities: Vec<[f64; 3]>,
396 pub forces: Vec<[f64; 3]>,
398 pub metadata: HashMap<String, f64>,
400}
401impl SimulationState {
402 pub fn new() -> Self {
404 Self {
405 positions: Vec::new(),
406 velocities: Vec::new(),
407 forces: Vec::new(),
408 metadata: HashMap::new(),
409 }
410 }
411 pub fn len(&self) -> usize {
413 self.positions.len()
414 }
415 pub fn is_empty(&self) -> bool {
417 self.positions.is_empty()
418 }
419 pub fn to_bytes(&self) -> Vec<u8> {
421 let mut buf = Vec::new();
422 let n = self.positions.len() as u64;
423 buf.extend_from_slice(&n.to_le_bytes());
424 for pos in &self.positions {
425 for &c in pos {
426 buf.extend_from_slice(&c.to_le_bytes());
427 }
428 }
429 for vel in &self.velocities {
430 for &c in vel {
431 buf.extend_from_slice(&c.to_le_bytes());
432 }
433 }
434 for frc in &self.forces {
435 for &c in frc {
436 buf.extend_from_slice(&c.to_le_bytes());
437 }
438 }
439 buf
440 }
441 pub fn from_bytes(data: &[u8]) -> io::Result<Self> {
443 let mut cur = 0usize;
444 let n = read_u64(data, &mut cur)? as usize;
445 let mut positions = Vec::with_capacity(n);
446 let mut velocities = Vec::with_capacity(n);
447 let mut forces = Vec::with_capacity(n);
448 for _ in 0..n {
449 let x = read_f64(data, &mut cur)?;
450 let y = read_f64(data, &mut cur)?;
451 let z = read_f64(data, &mut cur)?;
452 positions.push([x, y, z]);
453 }
454 for _ in 0..n {
455 let x = read_f64(data, &mut cur)?;
456 let y = read_f64(data, &mut cur)?;
457 let z = read_f64(data, &mut cur)?;
458 velocities.push([x, y, z]);
459 }
460 for _ in 0..n {
461 let x = read_f64(data, &mut cur)?;
462 let y = read_f64(data, &mut cur)?;
463 let z = read_f64(data, &mut cur)?;
464 forces.push([x, y, z]);
465 }
466 Ok(Self {
467 positions,
468 velocities,
469 forces,
470 metadata: HashMap::new(),
471 })
472 }
473}
474#[allow(dead_code)]
476#[derive(Debug, Clone)]
477pub enum RestartStrategy {
478 FromLatest,
480 FromStep(u64),
482 FromFile(String),
484}
485impl RestartStrategy {
486 pub fn resolve(&self, manager: &CheckpointManager) -> Option<PathBuf> {
488 match self {
489 Self::FromLatest => manager.latest_checkpoint(),
490 Self::FromStep(step) => {
491 let p = manager.checkpoint_path(*step);
492 if p.exists() { Some(p) } else { None }
493 }
494 Self::FromFile(path) => {
495 let p = PathBuf::from(path);
496 if p.exists() { Some(p) } else { None }
497 }
498 }
499 }
500 pub fn is_latest(&self) -> bool {
502 matches!(self, Self::FromLatest)
503 }
504}
505#[derive(Debug, Clone)]
515pub struct CheckpointWriter {
516 pub path: PathBuf,
518 pub compress: bool,
520}
521impl CheckpointWriter {
522 pub fn new(path: impl Into<PathBuf>) -> Self {
524 Self {
525 path: path.into(),
526 compress: false,
527 }
528 }
529 pub fn with_compress(mut self, compress: bool) -> Self {
531 self.compress = compress;
532 self
533 }
534 pub fn write_header(&self, meta: &CheckpointMetadata) -> io::Result<()> {
539 let mut f = BufWriter::new(File::create(&self.path)?);
540 f.write_all(&MAGIC.to_le_bytes())?;
541 f.write_all(&FORMAT_VERSION.to_le_bytes())?;
542 let meta_bytes = meta.to_bytes();
543 f.write_all(&(meta_bytes.len() as u32).to_le_bytes())?;
544 f.write_all(&meta_bytes)?;
545 f.flush()
546 }
547 pub fn write_positions(&self, pos: &[[f64; 3]]) -> io::Result<()> {
551 self.append_vec3_block(TAG_POSITIONS, pos)
552 }
553 pub fn write_velocities(&self, vel: &[[f64; 3]]) -> io::Result<()> {
555 self.append_vec3_block(TAG_VELOCITIES, vel)
556 }
557 pub fn write_scalars(&self, name: &str, data: &[f64]) -> io::Result<()> {
559 let mut f = self.open_append()?;
560 f.write_all(&[TAG_SCALARS])?;
561 write_name(&mut f, name)?;
562 f.write_all(&(data.len() as u64).to_le_bytes())?;
563 for &v in data {
564 f.write_all(&v.to_le_bytes())?;
565 }
566 f.flush()
567 }
568 pub fn write_integers(&self, name: &str, data: &[i32]) -> io::Result<()> {
570 let mut f = self.open_append()?;
571 f.write_all(&[TAG_INTEGERS])?;
572 write_name(&mut f, name)?;
573 f.write_all(&(data.len() as u64).to_le_bytes())?;
574 for &v in data {
575 f.write_all(&v.to_le_bytes())?;
576 }
577 f.flush()
578 }
579 pub fn finalize(&self) -> io::Result<()> {
587 let existing = fs::read(&self.path)?;
588 let csum = compute_checksum(&existing);
589 {
590 let mut f = self.open_append()?;
591 f.write_all(&[TAG_FOOTER])?;
592 f.write_all(&csum.to_le_bytes())?;
593 f.flush()?;
594 }
595 if self.compress {
596 let raw = fs::read(&self.path)?;
597 let compressed = oxiarc_zstd::compress_with_level(&raw, 3)
598 .map_err(|e| io::Error::other(format!("zstd compress: {e}")))?;
599 fs::write(&self.path, &compressed)?;
600 }
601 Ok(())
602 }
603 fn open_append(&self) -> io::Result<BufWriter<File>> {
604 Ok(BufWriter::new(
605 OpenOptions::new().append(true).open(&self.path)?,
606 ))
607 }
608 fn append_vec3_block(&self, tag: u8, data: &[[f64; 3]]) -> io::Result<()> {
609 let mut f = self.open_append()?;
610 f.write_all(&[tag])?;
611 f.write_all(&(data.len() as u64).to_le_bytes())?;
612 for p in data {
613 f.write_all(&p[0].to_le_bytes())?;
614 f.write_all(&p[1].to_le_bytes())?;
615 f.write_all(&p[2].to_le_bytes())?;
616 }
617 f.flush()
618 }
619}
620#[allow(dead_code)]
622#[derive(Debug, Clone)]
623pub struct DeltaCheckpoint {
624 pub base_step: u64,
626 pub target_step: u64,
628 pub changed_indices: Vec<usize>,
630 pub positions: Vec<[f64; 3]>,
632 pub velocities: Vec<[f64; 3]>,
634}
635impl DeltaCheckpoint {
636 pub fn compute(
641 base_step: u64,
642 target_step: u64,
643 base: &SimulationState,
644 target: &SimulationState,
645 tol: f64,
646 ) -> Self {
647 let n = base.positions.len().min(target.positions.len());
648 let mut changed_indices = Vec::new();
649 let mut positions = Vec::new();
650 let mut velocities = Vec::new();
651 for i in 0..n {
652 let pos_changed = base.positions[i]
653 .iter()
654 .zip(target.positions[i].iter())
655 .any(|(a, b)| (a - b).abs() > tol);
656 let vel_changed = if i < base.velocities.len() && i < target.velocities.len() {
657 base.velocities[i]
658 .iter()
659 .zip(target.velocities[i].iter())
660 .any(|(a, b)| (a - b).abs() > tol)
661 } else {
662 false
663 };
664 if pos_changed || vel_changed {
665 changed_indices.push(i);
666 positions.push(target.positions[i]);
667 if i < target.velocities.len() {
668 velocities.push(target.velocities[i]);
669 } else {
670 velocities.push([0.0; 3]);
671 }
672 }
673 }
674 Self {
675 base_step,
676 target_step,
677 changed_indices,
678 positions,
679 velocities,
680 }
681 }
682 pub fn num_changed(&self) -> usize {
684 self.changed_indices.len()
685 }
686 pub fn byte_size(&self) -> usize {
688 16 + self.changed_indices.len() * (8 + 3 * 8 + 3 * 8)
689 }
690 pub fn apply(&self, base: &SimulationState) -> SimulationState {
692 let mut out = base.clone();
693 for (k, &idx) in self.changed_indices.iter().enumerate() {
694 if idx < out.positions.len() {
695 out.positions[idx] = self.positions[k];
696 }
697 if idx < out.velocities.len() && k < self.velocities.len() {
698 out.velocities[idx] = self.velocities[k];
699 }
700 }
701 out
702 }
703}
704#[derive(Debug, Clone, PartialEq)]
706pub struct CheckpointMetadata {
707 pub step: u64,
709 pub time: f64,
711 pub n_particles: usize,
713 pub crate_version: [u32; 3],
715 pub created_at: String,
717}
718impl CheckpointMetadata {
719 pub fn new(
721 step: u64,
722 time: f64,
723 n_particles: usize,
724 crate_version: [u32; 3],
725 created_at: impl Into<String>,
726 ) -> Self {
727 Self {
728 step,
729 time,
730 n_particles,
731 crate_version,
732 created_at: created_at.into(),
733 }
734 }
735 pub(crate) fn to_bytes(&self) -> Vec<u8> {
737 let mut buf = Vec::new();
738 buf.extend_from_slice(&self.step.to_le_bytes());
739 buf.extend_from_slice(&self.time.to_le_bytes());
740 buf.extend_from_slice(&(self.n_particles as u64).to_le_bytes());
741 for v in &self.crate_version {
742 buf.extend_from_slice(&v.to_le_bytes());
743 }
744 let ts = self.created_at.as_bytes();
745 buf.extend_from_slice(&(ts.len() as u32).to_le_bytes());
746 buf.extend_from_slice(ts);
747 buf
748 }
749 pub(crate) fn from_bytes(data: &[u8]) -> io::Result<Self> {
751 let mut cursor = 0usize;
752 let step = read_u64(data, &mut cursor)?;
753 let time = read_f64(data, &mut cursor)?;
754 let n_particles = read_u64(data, &mut cursor)? as usize;
755 let v0 = read_u32(data, &mut cursor)?;
756 let v1 = read_u32(data, &mut cursor)?;
757 let v2 = read_u32(data, &mut cursor)?;
758 let ts_len = read_u32(data, &mut cursor)? as usize;
759 if cursor + ts_len > data.len() {
760 return Err(io::Error::new(
761 io::ErrorKind::UnexpectedEof,
762 "created_at string truncated",
763 ));
764 }
765 let created_at = String::from_utf8(data[cursor..cursor + ts_len].to_vec())
766 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, format!("UTF-8 error: {e}")))?;
767 Ok(Self {
768 step,
769 time,
770 n_particles,
771 crate_version: [v0, v1, v2],
772 created_at,
773 })
774 }
775}
776#[allow(dead_code)]
781#[derive(Debug, Default)]
782pub struct CheckpointMerger {
783 pub(super) parts: Vec<(usize, SimulationState)>,
785}
786impl CheckpointMerger {
787 pub fn new() -> Self {
789 Self::default()
790 }
791 pub fn add_part(&mut self, domain_id: usize, state: SimulationState) {
793 self.parts.push((domain_id, state));
794 }
795 pub fn merge(&mut self) -> SimulationState {
799 self.parts.sort_by_key(|(id, _)| *id);
800 let mut merged = SimulationState::new();
801 for (_, part) in &self.parts {
802 merged.positions.extend_from_slice(&part.positions);
803 merged.velocities.extend_from_slice(&part.velocities);
804 merged.forces.extend_from_slice(&part.forces);
805 }
806 merged
807 }
808 pub fn num_parts(&self) -> usize {
810 self.parts.len()
811 }
812 pub fn total_particles(&self) -> usize {
814 self.parts.iter().map(|(_, s)| s.len()).sum()
815 }
816}
/// Byte-level diff between two serialized states.
///
/// The diff records only differing bytes, not the target length, so applying it
/// never shortens the base: bytes missing from a shorter target become zeros, and
/// trailing zero bytes of a longer target are not recorded.
#[derive(Debug, Clone)]
pub struct CheckpointDiff {
    /// Step of the base state.
    pub base_step: u64,
    /// Step of the target state.
    pub target_step: u64,
    /// `(offset, old_byte, new_byte)` for every differing offset, in ascending order.
    pub edits: Vec<(usize, u8, u8)>,
}

impl CheckpointDiff {
    /// Compares `base` and `target` byte by byte over the longer of the two lengths,
    /// treating bytes past the end of either input as `0`.
    pub fn compute(base_step: u64, base: &[u8], target_step: u64, target: &[u8]) -> Self {
        let span = base.len().max(target.len());
        let edits = (0..span)
            .filter_map(|offset| {
                let old = base.get(offset).copied().unwrap_or(0);
                let new = target.get(offset).copied().unwrap_or(0);
                (old != new).then_some((offset, old, new))
            })
            .collect();
        Self {
            base_step,
            target_step,
            edits,
        }
    }

    /// Applies the edits to a copy of `base_state`, zero-extending it first when an
    /// edit lies past its end. The recorded old bytes are not checked.
    ///
    /// # Errors
    ///
    /// Returns [`io::ErrorKind::InvalidData`] if an edit offset is still out of
    /// bounds after the extension.
    pub fn apply(&self, base_state: &[u8]) -> io::Result<Vec<u8>> {
        let mut patched = base_state.to_vec();
        if let Some(highest) = self.edits.iter().map(|edit| edit.0).max() {
            if highest >= patched.len() {
                patched.resize(highest + 1, 0);
            }
        }
        for &(offset, _, replacement) in &self.edits {
            match patched.get_mut(offset) {
                Some(slot) => *slot = replacement,
                None => {
                    return Err(io::Error::new(
                        io::ErrorKind::InvalidData,
                        format!("diff edit offset {offset} out of bounds"),
                    ));
                }
            }
        }
        Ok(patched)
    }

    /// Number of recorded edits.
    pub fn diff_size(&self) -> usize {
        self.edits.len()
    }

    /// Fraction of edited bytes relative to `base_len`; `0.0` when `base_len` is zero.
    pub fn change_ratio(&self, base_len: usize) -> f64 {
        match base_len {
            0 => 0.0_f64,
            len => self.edits.len() as f64 / len as f64,
        }
    }
}
/// Decides when to checkpoint and manages the `checkpoint_<step>.bin` files in a directory.
#[derive(Debug, Clone)]
pub struct CheckpointManager {
    /// Directory holding the checkpoint files.
    pub base_dir: PathBuf,
    /// Maximum number of checkpoint files kept by `prune_old_checkpoints`.
    pub max_checkpoints: usize,
    /// Checkpoint every this many steps; `0` disables checkpointing.
    pub interval_steps: u64,
}

impl CheckpointManager {
    /// Creates a manager for `base_dir`.
    pub fn new(base_dir: impl Into<PathBuf>, max_checkpoints: usize, interval_steps: u64) -> Self {
        Self {
            base_dir: base_dir.into(),
            max_checkpoints,
            interval_steps,
        }
    }

    /// Returns `true` when `step` is a multiple of `interval_steps`.
    ///
    /// Always `false` when `interval_steps` is zero.
    pub fn should_checkpoint(&self, step: u64) -> bool {
        // `checked_rem` yields `None` for a zero interval, which disables checkpointing.
        step.checked_rem(self.interval_steps) == Some(0)
    }

    /// Path of the checkpoint file for `step` (step zero-padded to at least 10 digits).
    pub fn checkpoint_path(&self, step: u64) -> PathBuf {
        self.base_dir.join(format!("checkpoint_{step:010}.bin"))
    }

    /// Lists all `checkpoint_*.bin` files in `base_dir`, oldest first.
    ///
    /// Files are ordered by the step number parsed from their name, not by the raw
    /// file name: zero-padding only covers 10 digits, so a lexicographic sort would
    /// put an 11-digit step before every 10-digit one. Matching files whose step
    /// part is not a number sort after all numbered files, by path. An unreadable
    /// directory yields an empty list.
    pub fn list_checkpoints(&self) -> Vec<PathBuf> {
        let Ok(entries) = fs::read_dir(&self.base_dir) else {
            return vec![];
        };
        // Sort key: (has no numeric step, step, path).
        let mut keyed: Vec<(bool, u64, PathBuf)> = entries
            .flatten()
            .filter_map(|e| {
                let p = e.path();
                let name = p.file_name()?.to_string_lossy().into_owned();
                if !(name.starts_with("checkpoint_") && name.ends_with(".bin")) {
                    return None;
                }
                let step = Self::parse_step(&name);
                Some((step.is_none(), step.unwrap_or(0), p))
            })
            .collect();
        keyed.sort();
        keyed.into_iter().map(|(_, _, p)| p).collect()
    }

    /// Last entry of [`CheckpointManager::list_checkpoints`], if any.
    pub fn latest_checkpoint(&self) -> Option<PathBuf> {
        self.list_checkpoints().into_iter().last()
    }

    /// Deletes the oldest checkpoint files until at most `max_checkpoints` remain.
    ///
    /// # Errors
    ///
    /// Stops at and returns the first file-removal error.
    pub fn prune_old_checkpoints(&self) -> io::Result<()> {
        let checkpoints = self.list_checkpoints();
        let excess = checkpoints.len().saturating_sub(self.max_checkpoints);
        for path in checkpoints.iter().take(excess) {
            fs::remove_file(path)?;
        }
        Ok(())
    }

    /// Extracts the step number from a `checkpoint_<step>.bin` file name.
    fn parse_step(name: &str) -> Option<u64> {
        name.strip_prefix("checkpoint_")?
            .strip_suffix(".bin")?
            .parse()
            .ok()
    }
}
/// Simple byte-oriented compressor using literal runs and back-references.
///
/// Stream format (a sequence of records):
/// - `0x00, len, <len literal bytes>`: copy `len` bytes verbatim;
/// - `0x01, off_lo, off_hi, len`: copy `len` bytes starting `offset` bytes before the
///   current end of the output, where `offset = off_lo | off_hi << 8`.
#[derive(Debug, Clone, Default)]
pub struct CheckpointCompressor {
    /// Shortest match worth encoding as a back-reference (values below 1 act as 1).
    pub min_match_len: usize,
    /// How far back to search for matches (clamped to `1..=65535` when compressing).
    pub max_look_back: usize,
}

impl CheckpointCompressor {
    /// Largest back-reference offset that fits the two-byte offset field.
    const MAX_OFFSET: usize = u16::MAX as usize;
    /// Largest literal run or match length that fits the one-byte length field.
    const MAX_RUN: usize = u8::MAX as usize;

    /// Creates a compressor with `min_match_len = 3` and `max_look_back = 255`.
    pub fn new() -> Self {
        Self {
            min_match_len: 3,
            max_look_back: 255,
        }
    }

    /// Compresses `input` into the record stream described on the type.
    ///
    /// At each position the longest non-overlapping match inside the look-back
    /// window is used if it reaches `min_match_len`; otherwise up to 255 bytes are
    /// emitted as one literal run.
    pub fn compress(&self, input: &[u8]) -> Vec<u8> {
        let min_match = self.min_match_len.max(1);
        // The offset is stored in two bytes: searching further back than 65535 would
        // produce offsets that get truncated on output and corrupt the stream.
        let look_back = self.max_look_back.clamp(1, Self::MAX_OFFSET);
        let mut out = Vec::new();
        let mut pos = 0usize;
        while pos < input.len() {
            let window_start = pos.saturating_sub(look_back);
            let mut best_off = 0usize;
            let mut best_len = 0usize;
            for start in window_start..pos {
                let mut len = 0usize;
                while pos + len < input.len()
                    && input[start + len] == input[pos + len]
                    && len < Self::MAX_RUN
                {
                    len += 1;
                    // Matches never run into the bytes currently being encoded.
                    if start + len >= pos {
                        break;
                    }
                }
                if len > best_len && len >= min_match {
                    best_len = len;
                    best_off = pos - start;
                }
            }
            if best_len >= min_match {
                out.push(0x01);
                out.push((best_off & 0xFF) as u8);
                out.push(((best_off >> 8) & 0xFF) as u8);
                out.push(best_len as u8);
                pos += best_len;
            } else {
                let run_end = (pos + Self::MAX_RUN).min(input.len());
                let run_len = run_end - pos;
                out.push(0x00);
                out.push(run_len as u8);
                out.extend_from_slice(&input[pos..run_end]);
                pos = run_end;
            }
        }
        out
    }

    /// Decompresses a record stream produced by [`CheckpointCompressor::compress`].
    ///
    /// # Errors
    ///
    /// Returns [`io::ErrorKind::UnexpectedEof`] for a truncated record and
    /// [`io::ErrorKind::InvalidData`] for an unknown tag or a back-reference whose
    /// offset is zero or points before the start of the output.
    pub fn decompress(&self, input: &[u8]) -> io::Result<Vec<u8>> {
        let mut out: Vec<u8> = Vec::new();
        let mut i = 0usize;
        while i < input.len() {
            let tag = input[i];
            i += 1;
            match tag {
                0x00 => {
                    if i >= input.len() {
                        return Err(io::Error::new(
                            io::ErrorKind::UnexpectedEof,
                            "literal run truncated",
                        ));
                    }
                    let run_len = input[i] as usize;
                    i += 1;
                    if i + run_len > input.len() {
                        return Err(io::Error::new(
                            io::ErrorKind::UnexpectedEof,
                            "literal data truncated",
                        ));
                    }
                    out.extend_from_slice(&input[i..i + run_len]);
                    i += run_len;
                }
                0x01 => {
                    if i + 3 > input.len() {
                        return Err(io::Error::new(
                            io::ErrorKind::UnexpectedEof,
                            "back-ref truncated",
                        ));
                    }
                    let offset = input[i] as usize | ((input[i + 1] as usize) << 8);
                    let length = input[i + 2] as usize;
                    i += 3;
                    if offset == 0 || offset > out.len() {
                        return Err(io::Error::new(
                            io::ErrorKind::InvalidData,
                            format!("invalid back-ref offset {offset}"),
                        ));
                    }
                    let start = out.len() - offset;
                    // Byte-wise copy: the source range may overlap bytes pushed here.
                    for k in 0..length {
                        let byte = out[start + k];
                        out.push(byte);
                    }
                }
                _ => {
                    return Err(io::Error::new(
                        io::ErrorKind::InvalidData,
                        format!("unknown tag 0x{tag:02X}"),
                    ));
                }
            }
        }
        Ok(out)
    }

    /// Ratio `compressed_len / original_len`; `1.0` when `original_len` is zero.
    pub fn compression_ratio(original_len: usize, compressed_len: usize) -> f64 {
        if original_len == 0 {
            return 1.0_f64;
        }
        compressed_len as f64 / original_len as f64
    }
}
1091#[derive(Debug, Clone)]
1096pub struct CheckpointCatalog {
1097 pub base_dir: PathBuf,
1099 pub entries: Vec<(u64, PathBuf)>,
1101}
1102impl CheckpointCatalog {
1103 pub fn scan(base_dir: impl Into<PathBuf>) -> Self {
1108 let base_dir: PathBuf = base_dir.into();
1109 let mut entries: Vec<(u64, PathBuf)> = Vec::new();
1110 if let Ok(dir_entries) = fs::read_dir(&base_dir) {
1111 for entry in dir_entries.flatten() {
1112 let path = entry.path();
1113 if let Some(name) = path.file_name().and_then(|n| n.to_str())
1114 && name.starts_with("checkpoint_")
1115 && name.ends_with(".bin")
1116 {
1117 let step_str = &name[11..name.len() - 4];
1118 if let Ok(step) = step_str.parse::<u64>() {
1119 entries.push((step, path));
1120 }
1121 }
1122 }
1123 }
1124 entries.sort_by_key(|(s, _)| *s);
1125 Self { base_dir, entries }
1126 }
1127 pub fn len(&self) -> usize {
1129 self.entries.len()
1130 }
1131 pub fn is_empty(&self) -> bool {
1133 self.entries.is_empty()
1134 }
1135 pub fn steps(&self) -> Vec<u64> {
1137 self.entries.iter().map(|(s, _)| *s).collect()
1138 }
1139 pub fn path_for_step(&self, step: u64) -> Option<&PathBuf> {
1141 self.entries
1142 .binary_search_by_key(&step, |(s, _)| *s)
1143 .ok()
1144 .map(|idx| &self.entries[idx].1)
1145 }
1146 pub fn load_step(&self, step: u64) -> io::Result<Checkpoint> {
1151 let path = self.path_for_step(step).ok_or_else(|| {
1152 io::Error::new(
1153 io::ErrorKind::NotFound,
1154 format!("step {step} not in catalog"),
1155 )
1156 })?;
1157 let data = fs::read(path)?;
1158 Checkpoint::from_bytes(&data)
1159 }
1160 pub fn latest(&self) -> Option<&PathBuf> {
1162 self.entries.last().map(|(_, p)| p)
1163 }
1164 pub fn earliest(&self) -> Option<&PathBuf> {
1166 self.entries.first().map(|(_, p)| p)
1167 }
1168 pub fn add(&mut self, checkpoint: &Checkpoint) -> io::Result<()> {
1172 let path = self
1173 .base_dir
1174 .join(format!("checkpoint_{:010}.bin", checkpoint.step));
1175 let bytes = checkpoint.to_bytes();
1176 fs::write(&path, &bytes)?;
1177 let pos = self.entries.partition_point(|(s, _)| *s < checkpoint.step);
1178 self.entries.insert(pos, (checkpoint.step, path));
1179 Ok(())
1180 }
1181 pub fn remove_step(&mut self, step: u64) -> io::Result<()> {
1183 let pos = self
1184 .entries
1185 .binary_search_by_key(&step, |(s, _)| *s)
1186 .map_err(|_| {
1187 io::Error::new(
1188 io::ErrorKind::NotFound,
1189 format!("step {step} not in catalog"),
1190 )
1191 })?;
1192 let (_, path) = self.entries.remove(pos);
1193 if path.exists() {
1194 fs::remove_file(&path)?;
1195 }
1196 Ok(())
1197 }
1198}
1199#[derive(Debug, Clone)]
1201pub struct CheckpointReader {
1202 pub path: PathBuf,
1204}
1205impl CheckpointReader {
1206 pub fn new(path: impl Into<PathBuf>) -> Self {
1208 Self { path: path.into() }
1209 }
1210
1211 fn read_raw_bytes(&self) -> io::Result<Vec<u8>> {
1214 let raw = fs::read(&self.path)?;
1215 if raw.starts_with(&ZSTD_MAGIC) {
1216 oxiarc_zstd::decompress(&raw).map_err(|e| {
1217 io::Error::new(io::ErrorKind::InvalidData, format!("zstd decompress: {e}"))
1218 })
1219 } else {
1220 Ok(raw)
1221 }
1222 }
1223
1224 pub fn read_metadata(&self) -> io::Result<CheckpointMetadata> {
1226 let data = self.read_raw_bytes()?;
1227 let mut cursor = 0usize;
1228 let magic = read_u32(&data, &mut cursor)?;
1229 if magic != MAGIC {
1230 return Err(io::Error::new(
1231 io::ErrorKind::InvalidData,
1232 "bad magic number",
1233 ));
1234 }
1235 let _version = read_u32(&data, &mut cursor)?;
1236 let meta_len = read_u32(&data, &mut cursor)? as usize;
1237 if cursor + meta_len > data.len() {
1238 return Err(io::Error::new(
1239 io::ErrorKind::UnexpectedEof,
1240 "metadata block truncated",
1241 ));
1242 }
1243 CheckpointMetadata::from_bytes(&data[cursor..cursor + meta_len])
1244 }
1245 pub fn read_positions(&self) -> io::Result<Vec<[f64; 3]>> {
1247 self.read_vec3_block(TAG_POSITIONS)
1248 }
1249 pub fn read_velocities(&self) -> io::Result<Vec<[f64; 3]>> {
1251 self.read_vec3_block(TAG_VELOCITIES)
1252 }
1253 pub fn read_scalars(&self, name: &str) -> io::Result<Vec<f64>> {
1255 let data = self.read_raw_bytes()?;
1256 let mut cursor = self.skip_header(&data)?;
1257 while cursor < data.len() {
1258 let tag = data[cursor];
1259 cursor += 1;
1260 match tag {
1261 TAG_SCALARS => {
1262 let stored_name = read_name(&data, &mut cursor)?;
1263 let count = read_u64(&data, &mut cursor)? as usize;
1264 if stored_name == name {
1265 let mut out = Vec::with_capacity(count);
1266 for _ in 0..count {
1267 out.push(read_f64(&data, &mut cursor)?);
1268 }
1269 return Ok(out);
1270 } else {
1271 cursor += count * 8;
1272 }
1273 }
1274 TAG_POSITIONS | TAG_VELOCITIES => {
1275 let count = read_u64(&data, &mut cursor)? as usize;
1276 cursor += count * 24;
1277 }
1278 TAG_INTEGERS => {
1279 let _n = read_name(&data, &mut cursor)?;
1280 let count = read_u64(&data, &mut cursor)? as usize;
1281 cursor += count * 4;
1282 }
1283 TAG_FOOTER => break,
1284 _ => {
1285 return Err(io::Error::new(
1286 io::ErrorKind::InvalidData,
1287 format!("unknown tag 0x{tag:02X}"),
1288 ));
1289 }
1290 }
1291 }
1292 Err(io::Error::new(
1293 io::ErrorKind::NotFound,
1294 format!("scalar array '{name}' not found"),
1295 ))
1296 }
1297 fn skip_header(&self, data: &[u8]) -> io::Result<usize> {
1298 let mut cursor = 0usize;
1299 let _magic = read_u32(data, &mut cursor)?;
1300 let _version = read_u32(data, &mut cursor)?;
1301 let meta_len = read_u32(data, &mut cursor)? as usize;
1302 cursor += meta_len;
1303 Ok(cursor)
1304 }
1305 fn read_vec3_block(&self, target_tag: u8) -> io::Result<Vec<[f64; 3]>> {
1306 let data = self.read_raw_bytes()?;
1307 let mut cursor = self.skip_header(&data)?;
1308 while cursor < data.len() {
1309 let tag = data[cursor];
1310 cursor += 1;
1311 match tag {
1312 t if t == target_tag => {
1313 let count = read_u64(&data, &mut cursor)? as usize;
1314 let mut out = Vec::with_capacity(count);
1315 for _ in 0..count {
1316 let x = read_f64(&data, &mut cursor)?;
1317 let y = read_f64(&data, &mut cursor)?;
1318 let z = read_f64(&data, &mut cursor)?;
1319 out.push([x, y, z]);
1320 }
1321 return Ok(out);
1322 }
1323 TAG_POSITIONS | TAG_VELOCITIES => {
1324 let count = read_u64(&data, &mut cursor)? as usize;
1325 cursor += count * 24;
1326 }
1327 TAG_SCALARS | TAG_INTEGERS => {
1328 let _n = read_name(&data, &mut cursor)?;
1329 let count = read_u64(&data, &mut cursor)? as usize;
1330 let elem_size = if tag == TAG_SCALARS { 8 } else { 4 };
1331 cursor += count * elem_size;
1332 }
1333 TAG_FOOTER => break,
1334 _ => {
1335 return Err(io::Error::new(
1336 io::ErrorKind::InvalidData,
1337 format!("unknown tag 0x{tag:02X}"),
1338 ));
1339 }
1340 }
1341 }
1342 Ok(vec![])
1343 }
1344}
1345#[derive(Debug, Clone)]
1347pub struct CheckpointFileReader {
1348 pub path: PathBuf,
1350}
1351impl CheckpointFileReader {
1352 pub fn new(path: impl Into<PathBuf>) -> Self {
1354 Self { path: path.into() }
1355 }
1356 pub fn read_and_validate(&self) -> io::Result<Checkpoint> {
1361 let data = fs::read(&self.path)?;
1362 let ckpt = Checkpoint::from_bytes(&data)?;
1363 if !ckpt.verify() {
1364 return Err(io::Error::new(
1365 io::ErrorKind::InvalidData,
1366 "checkpoint checksum mismatch",
1367 ));
1368 }
1369 Ok(ckpt)
1370 }
1371}