1use std::collections::HashMap;
6use std::fs::{self, File, OpenOptions};
7use std::io::{self, BufWriter, Write};
8use std::path::{Path, PathBuf};
9
10#[allow(unused_imports)]
11use super::functions::*;
12use super::functions::{
13 FORMAT_VERSION, MAGIC, TAG_FOOTER, TAG_INTEGERS, TAG_POSITIONS, TAG_SCALARS, TAG_VELOCITIES,
14};
15
16#[derive(Debug, Clone, PartialEq)]
21pub struct Checkpoint {
22 pub version: u32,
24 pub timestamp: u64,
26 pub step: u64,
28 pub sim_time: f64,
30 pub state: Vec<u8>,
32 pub checksum: u32,
34}
35impl Checkpoint {
36 pub fn new(version: u32, timestamp: u64, step: u64, sim_time: f64, state: Vec<u8>) -> Self {
38 let checksum = compute_checksum(&state);
39 Self {
40 version,
41 timestamp,
42 step,
43 sim_time,
44 state,
45 checksum,
46 }
47 }
48 pub fn compute_checksum(&mut self) {
50 self.checksum = compute_checksum(&self.state);
51 }
52 pub fn verify(&self) -> bool {
54 compute_checksum(&self.state) == self.checksum
55 }
56 pub fn to_bytes(&self) -> Vec<u8> {
67 let mut buf = Vec::new();
68 buf.extend_from_slice(&self.version.to_le_bytes());
69 buf.extend_from_slice(&self.timestamp.to_le_bytes());
70 buf.extend_from_slice(&self.step.to_le_bytes());
71 buf.extend_from_slice(&self.sim_time.to_bits().to_le_bytes());
72 buf.extend_from_slice(&self.checksum.to_le_bytes());
73 buf.extend_from_slice(&(self.state.len() as u64).to_le_bytes());
74 buf.extend_from_slice(&self.state);
75 buf
76 }
77 pub fn from_bytes(data: &[u8]) -> io::Result<Self> {
79 let mut cursor = 0usize;
80 let version = read_u32(data, &mut cursor)?;
81 let timestamp = read_u64(data, &mut cursor)?;
82 let step = read_u64(data, &mut cursor)?;
83 let sim_time = read_f64(data, &mut cursor)?;
84 let checksum = read_u32(data, &mut cursor)?;
85 let state_len = read_u64(data, &mut cursor)? as usize;
86 if cursor + state_len > data.len() {
87 return Err(io::Error::new(
88 io::ErrorKind::UnexpectedEof,
89 "state truncated",
90 ));
91 }
92 let state = data[cursor..cursor + state_len].to_vec();
93 Ok(Self {
94 version,
95 timestamp,
96 step,
97 sim_time,
98 state,
99 checksum,
100 })
101 }
102}
103#[allow(dead_code)]
105#[derive(Debug)]
106pub struct CheckpointInspector {
107 pub base_dir: PathBuf,
109}
110impl CheckpointInspector {
111 pub fn new(base_dir: impl Into<PathBuf>) -> Self {
113 Self {
114 base_dir: base_dir.into(),
115 }
116 }
117 pub fn list(&self) -> Vec<PathBuf> {
119 let mut paths: Vec<PathBuf> = match fs::read_dir(&self.base_dir) {
120 Ok(rd) => rd
121 .filter_map(|e| e.ok())
122 .map(|e| e.path())
123 .filter(|p| p.extension().is_some_and(|ext| ext == "bin"))
124 .collect(),
125 Err(_) => Vec::new(),
126 };
127 paths.sort();
128 paths
129 }
130 pub fn peek_header(&self, path: &Path) -> io::Result<Checkpoint> {
132 let data = fs::read(path)?;
133 Checkpoint::from_bytes(&data)
134 }
135 pub fn metadata_summary(&self) -> Vec<(u64, f64)> {
137 self.list()
138 .iter()
139 .filter_map(|p| self.peek_header(p).ok())
140 .map(|c| (c.step, c.sim_time))
141 .collect()
142 }
143 pub fn count(&self) -> usize {
145 self.list().len()
146 }
147}
148#[derive(Debug, Clone)]
151pub struct RestartFile {
152 pub meta: CheckpointMetadata,
154 pub positions: Vec<[f64; 3]>,
156 pub velocities: Vec<[f64; 3]>,
158 pub scalars: HashMap<String, Vec<f64>>,
160}
161impl RestartFile {
162 pub fn new(
164 meta: CheckpointMetadata,
165 positions: Vec<[f64; 3]>,
166 velocities: Vec<[f64; 3]>,
167 scalars: HashMap<String, Vec<f64>>,
168 ) -> Self {
169 Self {
170 meta,
171 positions,
172 velocities,
173 scalars,
174 }
175 }
176 pub fn save(&self, path: &Path) -> io::Result<()> {
178 let writer = CheckpointWriter::new(path);
179 writer.write_header(&self.meta)?;
180 writer.write_positions(&self.positions)?;
181 writer.write_velocities(&self.velocities)?;
182 let mut keys: Vec<&String> = self.scalars.keys().collect();
183 keys.sort();
184 for k in keys {
185 writer.write_scalars(k, &self.scalars[k])?;
186 }
187 writer.finalize()
188 }
189 pub fn load(path: &Path) -> io::Result<Self> {
191 let reader = CheckpointReader::new(path);
192 let meta = reader.read_metadata()?;
193 let positions = reader.read_positions()?;
194 let velocities = reader.read_velocities()?;
195 let data = fs::read(path)?;
196 let mut scalars: HashMap<String, Vec<f64>> = HashMap::new();
197 let mut cursor = {
198 let mut c = 0usize;
199 let _magic = read_u32(&data, &mut c)?;
200 let _version = read_u32(&data, &mut c)?;
201 let meta_len = read_u32(&data, &mut c)? as usize;
202 c += meta_len;
203 c
204 };
205 while cursor < data.len() {
206 let tag = data[cursor];
207 cursor += 1;
208 match tag {
209 TAG_POSITIONS | TAG_VELOCITIES => {
210 let count = read_u64(&data, &mut cursor)? as usize;
211 cursor += count * 24;
212 }
213 TAG_SCALARS => {
214 let name = read_name(&data, &mut cursor)?;
215 let count = read_u64(&data, &mut cursor)? as usize;
216 let mut vals = Vec::with_capacity(count);
217 for _ in 0..count {
218 vals.push(read_f64(&data, &mut cursor)?);
219 }
220 scalars.insert(name, vals);
221 }
222 TAG_INTEGERS => {
223 let _name = read_name(&data, &mut cursor)?;
224 let count = read_u64(&data, &mut cursor)? as usize;
225 cursor += count * 4;
226 }
227 TAG_FOOTER => break,
228 _ => {
229 return Err(io::Error::new(
230 io::ErrorKind::InvalidData,
231 format!("unknown tag 0x{tag:02X} while loading restart"),
232 ));
233 }
234 }
235 }
236 Ok(Self {
237 meta,
238 positions,
239 velocities,
240 scalars,
241 })
242 }
243}
/// On-disk encodings a checkpoint can be stored in.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CheckpointFormat {
    /// Raw binary encoding.
    Binary,
    /// JSON text encoding.
    Json,
    /// Compressed encoding (file extension `rle`).
    Compressed,
    /// HDF5-like encoding (file extension `h5xt`).
    HDF5Like,
}

impl CheckpointFormat {
    /// File extension (without the leading dot) used for this format.
    pub fn extension(self) -> &'static str {
        match self {
            Self::HDF5Like => "h5xt",
            Self::Compressed => "rle",
            Self::Json => "json",
            Self::Binary => "bin",
        }
    }

    /// `true` only for the text-based format, which is [`CheckpointFormat::Json`].
    pub fn is_text(self) -> bool {
        self == Self::Json
    }
}
272#[allow(dead_code)]
274#[derive(Debug, Clone, PartialEq)]
275pub struct CheckpointHeader {
276 pub version: [u32; 3],
278 pub timestamp: u64,
280 pub step: u64,
282 pub crate_name: String,
284 pub checksum: u32,
286}
287impl CheckpointHeader {
288 pub fn new(
290 version: [u32; 3],
291 timestamp: u64,
292 step: u64,
293 crate_name: impl Into<String>,
294 checksum: u32,
295 ) -> Self {
296 Self {
297 version,
298 timestamp,
299 step,
300 crate_name: crate_name.into(),
301 checksum,
302 }
303 }
304 pub fn to_bytes(&self) -> Vec<u8> {
306 let mut buf = Vec::new();
307 for v in &self.version {
308 buf.extend_from_slice(&v.to_le_bytes());
309 }
310 buf.extend_from_slice(&self.timestamp.to_le_bytes());
311 buf.extend_from_slice(&self.step.to_le_bytes());
312 let name_bytes = self.crate_name.as_bytes();
313 buf.extend_from_slice(&(name_bytes.len() as u32).to_le_bytes());
314 buf.extend_from_slice(name_bytes);
315 buf.extend_from_slice(&self.checksum.to_le_bytes());
316 buf
317 }
318 pub fn from_bytes(data: &[u8]) -> io::Result<Self> {
320 let mut c = 0usize;
321 let v0 = read_u32(data, &mut c)?;
322 let v1 = read_u32(data, &mut c)?;
323 let v2 = read_u32(data, &mut c)?;
324 let timestamp = read_u64(data, &mut c)?;
325 let step = read_u64(data, &mut c)?;
326 let name_len = read_u32(data, &mut c)? as usize;
327 if c + name_len > data.len() {
328 return Err(io::Error::new(
329 io::ErrorKind::UnexpectedEof,
330 "header crate_name truncated",
331 ));
332 }
333 let crate_name = std::str::from_utf8(&data[c..c + name_len])
334 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?
335 .to_owned();
336 c += name_len;
337 let checksum = read_u32(data, &mut c)?;
338 Ok(Self {
339 version: [v0, v1, v2],
340 timestamp,
341 step,
342 crate_name,
343 checksum,
344 })
345 }
346 pub fn version_compatible(&self, expected_major: u32) -> bool {
348 self.version[0] == expected_major
349 }
350}
351#[derive(Debug, Clone)]
354pub struct CheckpointFileWriter {
355 pub output_dir: PathBuf,
357}
358impl CheckpointFileWriter {
359 pub fn new(output_dir: impl Into<PathBuf>) -> Self {
361 Self {
362 output_dir: output_dir.into(),
363 }
364 }
365 pub fn write(&self, checkpoint: &Checkpoint) -> io::Result<PathBuf> {
368 let base_name = format!("checkpoint_{:010}", checkpoint.step);
369 let bin_path = self.output_dir.join(format!("{base_name}.bin"));
370 let json_path = self.output_dir.join(format!("{base_name}.json"));
371 fs::write(&bin_path, checkpoint.to_bytes())?;
372 let json = format!(
373 r#"{{"version":{},"timestamp":{},"step":{},"sim_time":{},"state_len":{},"checksum":{}}}"#,
374 checkpoint.version,
375 checkpoint.timestamp,
376 checkpoint.step,
377 checkpoint.sim_time,
378 checkpoint.state.len(),
379 checkpoint.checksum
380 );
381 fs::write(&json_path, json.as_bytes())?;
382 Ok(bin_path)
383 }
384}
/// Particle state of a simulation: positions, velocities, forces and free-form metadata.
///
/// The binary form is a little-endian `u64` particle count `n` followed by exactly `n`
/// position vectors, `n` velocity vectors and `n` force vectors, each made of three
/// little-endian `f64` values. `metadata` is not part of the binary form.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct SimulationState {
    /// Per-particle positions; their count defines the particle count.
    pub positions: Vec<[f64; 3]>,
    /// Per-particle velocities.
    pub velocities: Vec<[f64; 3]>,
    /// Per-particle forces.
    pub forces: Vec<[f64; 3]>,
    /// Named scalar values attached to the state (not serialized by `to_bytes`).
    pub metadata: HashMap<String, f64>,
}

impl SimulationState {
    /// Size of the serialized particle count.
    const COUNT_LEN: usize = 8;
    /// Size of one serialized three-component vector.
    const VEC3_LEN: usize = 3 * 8;

    /// Creates an empty state.
    pub fn new() -> Self {
        Self {
            positions: Vec::new(),
            velocities: Vec::new(),
            forces: Vec::new(),
            metadata: HashMap::new(),
        }
    }

    /// Number of particles, defined as the number of positions.
    pub fn len(&self) -> usize {
        self.positions.len()
    }

    /// `true` when the state holds no positions.
    pub fn is_empty(&self) -> bool {
        self.positions.is_empty()
    }

    /// Serializes the state into the layout described on [`SimulationState`].
    ///
    /// Exactly `positions.len()` entries are written for every array. A shorter
    /// `velocities` or `forces` array is padded with zero vectors and a longer one is
    /// cut off, so the output always matches the declared count and can be read back
    /// by [`SimulationState::from_bytes`].
    pub fn to_bytes(&self) -> Vec<u8> {
        let n = self.positions.len();
        let mut buf = Vec::with_capacity(Self::COUNT_LEN + 3 * n * Self::VEC3_LEN);
        buf.extend_from_slice(&(n as u64).to_le_bytes());
        Self::encode_vec3s(&mut buf, &self.positions, n);
        Self::encode_vec3s(&mut buf, &self.velocities, n);
        Self::encode_vec3s(&mut buf, &self.forces, n);
        buf
    }

    /// Parses a state produced by [`SimulationState::to_bytes`]; `metadata` comes back empty.
    ///
    /// Bytes after the three arrays are ignored.
    ///
    /// # Errors
    ///
    /// Returns [`io::ErrorKind::UnexpectedEof`] when the input is shorter than the
    /// count field or than the three arrays the count announces.
    pub fn from_bytes(data: &[u8]) -> io::Result<Self> {
        let truncated =
            || io::Error::new(io::ErrorKind::UnexpectedEof, "simulation state truncated");

        if data.len() < Self::COUNT_LEN {
            return Err(truncated());
        }
        let (head, body) = data.split_at(Self::COUNT_LEN);
        let mut raw = [0u8; 8];
        raw.copy_from_slice(head);
        let declared = u64::from_le_bytes(raw);

        // Validate the untrusted count against the bytes actually present before any
        // allocation, so corrupt input cannot request an absurd capacity.
        let max_particles = body.len() / (3 * Self::VEC3_LEN);
        if declared > max_particles as u64 {
            return Err(truncated());
        }
        let n = declared as usize;
        let section = n * Self::VEC3_LEN;

        Ok(Self {
            positions: Self::decode_vec3s(&body[..section]),
            velocities: Self::decode_vec3s(&body[section..2 * section]),
            forces: Self::decode_vec3s(&body[2 * section..3 * section]),
            metadata: HashMap::new(),
        })
    }

    /// Appends exactly `n` vectors from `src` to `buf`, zero-padding when `src` is shorter.
    fn encode_vec3s(buf: &mut Vec<u8>, src: &[[f64; 3]], n: usize) {
        let padded = src.iter().copied().chain(std::iter::repeat([0.0_f64; 3]));
        for vec3 in padded.take(n) {
            for component in &vec3 {
                buf.extend_from_slice(&component.to_le_bytes());
            }
        }
    }

    /// Decodes a byte slice whose length is a multiple of one vector into vectors.
    fn decode_vec3s(bytes: &[u8]) -> Vec<[f64; 3]> {
        bytes
            .chunks_exact(Self::VEC3_LEN)
            .map(|chunk| {
                let mut vec3 = [0.0_f64; 3];
                for (slot, word) in vec3.iter_mut().zip(chunk.chunks_exact(8)) {
                    let mut raw = [0u8; 8];
                    raw.copy_from_slice(word);
                    *slot = f64::from_le_bytes(raw);
                }
                vec3
            })
            .collect()
    }
}
471#[allow(dead_code)]
473#[derive(Debug, Clone)]
474pub enum RestartStrategy {
475 FromLatest,
477 FromStep(u64),
479 FromFile(String),
481}
482impl RestartStrategy {
483 pub fn resolve(&self, manager: &CheckpointManager) -> Option<PathBuf> {
485 match self {
486 Self::FromLatest => manager.latest_checkpoint(),
487 Self::FromStep(step) => {
488 let p = manager.checkpoint_path(*step);
489 if p.exists() { Some(p) } else { None }
490 }
491 Self::FromFile(path) => {
492 let p = PathBuf::from(path);
493 if p.exists() { Some(p) } else { None }
494 }
495 }
496 }
497 pub fn is_latest(&self) -> bool {
499 matches!(self, Self::FromLatest)
500 }
501}
502#[derive(Debug, Clone)]
512pub struct CheckpointWriter {
513 pub path: PathBuf,
515 pub compress: bool,
517}
518impl CheckpointWriter {
519 pub fn new(path: impl Into<PathBuf>) -> Self {
521 Self {
522 path: path.into(),
523 compress: false,
524 }
525 }
526 pub fn with_compress(mut self, compress: bool) -> Self {
528 self.compress = compress;
529 self
530 }
531 pub fn write_header(&self, meta: &CheckpointMetadata) -> io::Result<()> {
536 let mut f = BufWriter::new(File::create(&self.path)?);
537 f.write_all(&MAGIC.to_le_bytes())?;
538 f.write_all(&FORMAT_VERSION.to_le_bytes())?;
539 let meta_bytes = meta.to_bytes();
540 f.write_all(&(meta_bytes.len() as u32).to_le_bytes())?;
541 f.write_all(&meta_bytes)?;
542 f.flush()
543 }
544 pub fn write_positions(&self, pos: &[[f64; 3]]) -> io::Result<()> {
548 self.append_vec3_block(TAG_POSITIONS, pos)
549 }
550 pub fn write_velocities(&self, vel: &[[f64; 3]]) -> io::Result<()> {
552 self.append_vec3_block(TAG_VELOCITIES, vel)
553 }
554 pub fn write_scalars(&self, name: &str, data: &[f64]) -> io::Result<()> {
556 let mut f = self.open_append()?;
557 f.write_all(&[TAG_SCALARS])?;
558 write_name(&mut f, name)?;
559 f.write_all(&(data.len() as u64).to_le_bytes())?;
560 for &v in data {
561 f.write_all(&v.to_le_bytes())?;
562 }
563 f.flush()
564 }
565 pub fn write_integers(&self, name: &str, data: &[i32]) -> io::Result<()> {
567 let mut f = self.open_append()?;
568 f.write_all(&[TAG_INTEGERS])?;
569 write_name(&mut f, name)?;
570 f.write_all(&(data.len() as u64).to_le_bytes())?;
571 for &v in data {
572 f.write_all(&v.to_le_bytes())?;
573 }
574 f.flush()
575 }
576 pub fn finalize(&self) -> io::Result<()> {
581 let existing = fs::read(&self.path)?;
582 let csum = compute_checksum(&existing);
583 let mut f = self.open_append()?;
584 f.write_all(&[TAG_FOOTER])?;
585 f.write_all(&csum.to_le_bytes())?;
586 f.flush()
587 }
588 fn open_append(&self) -> io::Result<BufWriter<File>> {
589 Ok(BufWriter::new(
590 OpenOptions::new().append(true).open(&self.path)?,
591 ))
592 }
593 fn append_vec3_block(&self, tag: u8, data: &[[f64; 3]]) -> io::Result<()> {
594 let mut f = self.open_append()?;
595 f.write_all(&[tag])?;
596 f.write_all(&(data.len() as u64).to_le_bytes())?;
597 for p in data {
598 f.write_all(&p[0].to_le_bytes())?;
599 f.write_all(&p[1].to_le_bytes())?;
600 f.write_all(&p[2].to_le_bytes())?;
601 }
602 f.flush()
603 }
604}
605#[allow(dead_code)]
607#[derive(Debug, Clone)]
608pub struct DeltaCheckpoint {
609 pub base_step: u64,
611 pub target_step: u64,
613 pub changed_indices: Vec<usize>,
615 pub positions: Vec<[f64; 3]>,
617 pub velocities: Vec<[f64; 3]>,
619}
620impl DeltaCheckpoint {
621 pub fn compute(
626 base_step: u64,
627 target_step: u64,
628 base: &SimulationState,
629 target: &SimulationState,
630 tol: f64,
631 ) -> Self {
632 let n = base.positions.len().min(target.positions.len());
633 let mut changed_indices = Vec::new();
634 let mut positions = Vec::new();
635 let mut velocities = Vec::new();
636 for i in 0..n {
637 let pos_changed = base.positions[i]
638 .iter()
639 .zip(target.positions[i].iter())
640 .any(|(a, b)| (a - b).abs() > tol);
641 let vel_changed = if i < base.velocities.len() && i < target.velocities.len() {
642 base.velocities[i]
643 .iter()
644 .zip(target.velocities[i].iter())
645 .any(|(a, b)| (a - b).abs() > tol)
646 } else {
647 false
648 };
649 if pos_changed || vel_changed {
650 changed_indices.push(i);
651 positions.push(target.positions[i]);
652 if i < target.velocities.len() {
653 velocities.push(target.velocities[i]);
654 } else {
655 velocities.push([0.0; 3]);
656 }
657 }
658 }
659 Self {
660 base_step,
661 target_step,
662 changed_indices,
663 positions,
664 velocities,
665 }
666 }
667 pub fn num_changed(&self) -> usize {
669 self.changed_indices.len()
670 }
671 pub fn byte_size(&self) -> usize {
673 16 + self.changed_indices.len() * (8 + 3 * 8 + 3 * 8)
674 }
675 pub fn apply(&self, base: &SimulationState) -> SimulationState {
677 let mut out = base.clone();
678 for (k, &idx) in self.changed_indices.iter().enumerate() {
679 if idx < out.positions.len() {
680 out.positions[idx] = self.positions[k];
681 }
682 if idx < out.velocities.len() && k < self.velocities.len() {
683 out.velocities[idx] = self.velocities[k];
684 }
685 }
686 out
687 }
688}
689#[derive(Debug, Clone, PartialEq)]
691pub struct CheckpointMetadata {
692 pub step: u64,
694 pub time: f64,
696 pub n_particles: usize,
698 pub crate_version: [u32; 3],
700 pub created_at: String,
702}
703impl CheckpointMetadata {
704 pub fn new(
706 step: u64,
707 time: f64,
708 n_particles: usize,
709 crate_version: [u32; 3],
710 created_at: impl Into<String>,
711 ) -> Self {
712 Self {
713 step,
714 time,
715 n_particles,
716 crate_version,
717 created_at: created_at.into(),
718 }
719 }
720 pub(crate) fn to_bytes(&self) -> Vec<u8> {
722 let mut buf = Vec::new();
723 buf.extend_from_slice(&self.step.to_le_bytes());
724 buf.extend_from_slice(&self.time.to_le_bytes());
725 buf.extend_from_slice(&(self.n_particles as u64).to_le_bytes());
726 for v in &self.crate_version {
727 buf.extend_from_slice(&v.to_le_bytes());
728 }
729 let ts = self.created_at.as_bytes();
730 buf.extend_from_slice(&(ts.len() as u32).to_le_bytes());
731 buf.extend_from_slice(ts);
732 buf
733 }
734 pub(crate) fn from_bytes(data: &[u8]) -> io::Result<Self> {
736 let mut cursor = 0usize;
737 let step = read_u64(data, &mut cursor)?;
738 let time = read_f64(data, &mut cursor)?;
739 let n_particles = read_u64(data, &mut cursor)? as usize;
740 let v0 = read_u32(data, &mut cursor)?;
741 let v1 = read_u32(data, &mut cursor)?;
742 let v2 = read_u32(data, &mut cursor)?;
743 let ts_len = read_u32(data, &mut cursor)? as usize;
744 if cursor + ts_len > data.len() {
745 return Err(io::Error::new(
746 io::ErrorKind::UnexpectedEof,
747 "created_at string truncated",
748 ));
749 }
750 let created_at = String::from_utf8(data[cursor..cursor + ts_len].to_vec())
751 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, format!("UTF-8 error: {e}")))?;
752 Ok(Self {
753 step,
754 time,
755 n_particles,
756 crate_version: [v0, v1, v2],
757 created_at,
758 })
759 }
760}
761#[allow(dead_code)]
766#[derive(Debug, Default)]
767pub struct CheckpointMerger {
768 pub(super) parts: Vec<(usize, SimulationState)>,
770}
771impl CheckpointMerger {
772 pub fn new() -> Self {
774 Self::default()
775 }
776 pub fn add_part(&mut self, domain_id: usize, state: SimulationState) {
778 self.parts.push((domain_id, state));
779 }
780 pub fn merge(&mut self) -> SimulationState {
784 self.parts.sort_by_key(|(id, _)| *id);
785 let mut merged = SimulationState::new();
786 for (_, part) in &self.parts {
787 merged.positions.extend_from_slice(&part.positions);
788 merged.velocities.extend_from_slice(&part.velocities);
789 merged.forces.extend_from_slice(&part.forces);
790 }
791 merged
792 }
793 pub fn num_parts(&self) -> usize {
795 self.parts.len()
796 }
797 pub fn total_particles(&self) -> usize {
799 self.parts.iter().map(|(_, s)| s.len()).sum()
800 }
801}
/// Byte-level difference between two serialized states.
#[derive(Debug, Clone)]
pub struct CheckpointDiff {
    /// Step of the state the diff is relative to.
    pub base_step: u64,
    /// Step of the state the diff leads to.
    pub target_step: u64,
    /// Edits as `(offset, old_byte, new_byte)`, in ascending offset order when
    /// produced by [`CheckpointDiff::compute`].
    pub edits: Vec<(usize, u8, u8)>,
}

impl CheckpointDiff {
    /// Compares `base` and `target` byte by byte and records every differing offset.
    ///
    /// The shorter input is treated as if it were padded with zero bytes, so the
    /// lengths themselves are not recorded in the diff.
    pub fn compute(base_step: u64, base: &[u8], target_step: u64, target: &[u8]) -> Self {
        let span = base.len().max(target.len());
        let byte_at = |buf: &[u8], i: usize| buf.get(i).copied().unwrap_or(0);

        let edits = (0..span)
            .filter_map(|i| {
                let (old, new) = (byte_at(base, i), byte_at(target, i));
                (old != new).then_some((i, old, new))
            })
            .collect();

        Self {
            base_step,
            target_step,
            edits,
        }
    }

    /// Applies the edits to a copy of `base_state` and returns the result.
    ///
    /// The copy is zero-extended when an edit lies beyond its end. It is never
    /// shortened. The recorded old bytes are not checked against `base_state`.
    ///
    /// # Errors
    ///
    /// Returns [`io::ErrorKind::InvalidData`] if an edit offset is still out of
    /// bounds after the extension step.
    pub fn apply(&self, base_state: &[u8]) -> io::Result<Vec<u8>> {
        let mut patched = base_state.to_vec();

        // Grow once to cover the highest offset any edit touches.
        if let Some(highest) = self.edits.iter().map(|edit| edit.0).max() {
            if highest >= patched.len() {
                patched.resize(highest + 1, 0);
            }
        }

        for &(offset, _old, new) in &self.edits {
            match patched.get_mut(offset) {
                Some(slot) => *slot = new,
                None => {
                    return Err(io::Error::new(
                        io::ErrorKind::InvalidData,
                        format!("diff edit offset {offset} out of bounds"),
                    ));
                }
            }
        }
        Ok(patched)
    }

    /// Number of recorded edits.
    pub fn diff_size(&self) -> usize {
        self.edits.len()
    }

    /// Fraction of edits relative to `base_len`; `0.0` when `base_len` is zero.
    pub fn change_ratio(&self, base_len: usize) -> f64 {
        match base_len {
            0 => 0.0_f64,
            len => self.edits.len() as f64 / len as f64,
        }
    }
}
/// Decides when to checkpoint and manages the `checkpoint_<step>.bin` files of a directory.
#[derive(Debug, Clone)]
pub struct CheckpointManager {
    /// Directory holding the checkpoint files.
    pub base_dir: PathBuf,
    /// Maximum number of checkpoint files kept by [`CheckpointManager::prune_old_checkpoints`].
    pub max_checkpoints: usize,
    /// Checkpoint every this many steps; `0` disables checkpointing.
    pub interval_steps: u64,
}

impl CheckpointManager {
    /// Creates a manager. The directory is not created or scanned here.
    pub fn new(base_dir: impl Into<PathBuf>, max_checkpoints: usize, interval_steps: u64) -> Self {
        Self {
            base_dir: base_dir.into(),
            max_checkpoints,
            interval_steps,
        }
    }

    /// `true` when `step` is a multiple of the interval (step `0` included).
    /// Always `false` when the interval is `0`.
    pub fn should_checkpoint(&self, step: u64) -> bool {
        if self.interval_steps == 0 {
            return false;
        }
        step.is_multiple_of(self.interval_steps)
    }

    /// Path used for the checkpoint of `step`: `checkpoint_<step>.bin` with the step
    /// zero-padded to at least 10 digits.
    pub fn checkpoint_path(&self, step: u64) -> PathBuf {
        self.base_dir.join(format!("checkpoint_{step:010}.bin"))
    }

    /// Lists all `checkpoint_*.bin` files in the directory, oldest step first.
    ///
    /// Files are ordered by the numeric step parsed from the name, so steps with
    /// more than 10 digits still sort after smaller ones (a plain path sort would
    /// put `checkpoint_10000000000.bin` before `checkpoint_9999999999.bin`). Names
    /// whose middle part is not a number sort first, in path order. An unreadable
    /// directory yields an empty list.
    pub fn list_checkpoints(&self) -> Vec<PathBuf> {
        let Ok(entries) = fs::read_dir(&self.base_dir) else {
            return Vec::new();
        };
        let mut keyed: Vec<(Option<u64>, PathBuf)> = entries
            .flatten()
            .filter_map(|entry| {
                let path = entry.path();
                let name = path.file_name()?.to_string_lossy().into_owned();
                let stem = name.strip_prefix("checkpoint_")?.strip_suffix(".bin")?;
                // `None` (unparseable) orders before every `Some(step)`.
                Some((stem.parse::<u64>().ok(), path))
            })
            .collect();
        keyed.sort();
        keyed.into_iter().map(|(_, path)| path).collect()
    }

    /// The checkpoint with the highest step, if any exist.
    pub fn latest_checkpoint(&self) -> Option<PathBuf> {
        self.list_checkpoints().pop()
    }

    /// Deletes the oldest checkpoint files until at most `max_checkpoints` remain.
    ///
    /// Only the `.bin` files returned by [`CheckpointManager::list_checkpoints`] are
    /// removed.
    ///
    /// # Errors
    ///
    /// Stops at and returns the first file-removal error.
    pub fn prune_old_checkpoints(&self) -> io::Result<()> {
        let checkpoints = self.list_checkpoints();
        let excess = checkpoints.len().saturating_sub(self.max_checkpoints);
        checkpoints
            .iter()
            .take(excess)
            .try_for_each(|path| fs::remove_file(path))
    }
}
/// Small back-reference compressor for checkpoint payloads.
///
/// The stream is a sequence of tokens:
/// * `0x00, len, <len literal bytes>` for a literal run of up to 255 bytes;
/// * `0x01, off_lo, off_hi, len` to copy `len` bytes (up to 255) starting `offset`
///   bytes back in the output, with `offset` a 16-bit little-endian value.
#[derive(Debug, Clone, Default)]
pub struct CheckpointCompressor {
    /// Shortest match worth encoding as a back-reference (values below 1 act as 1).
    pub min_match_len: usize,
    /// How far back to search for matches (values below 1 act as 1; values above
    /// 65535 act as 65535 because the offset field is 16 bits wide).
    pub max_look_back: usize,
}

impl CheckpointCompressor {
    /// Longest literal run or match a single token can describe.
    const MAX_RUN: usize = u8::MAX as usize;
    /// Largest distance the 16-bit offset field can describe.
    const MAX_OFFSET: usize = u16::MAX as usize;

    /// Creates a compressor with a minimum match length of 3 and a 255-byte window.
    pub fn new() -> Self {
        Self {
            min_match_len: 3,
            max_look_back: 255,
        }
    }

    /// Compresses `input` into the token stream described on [`CheckpointCompressor`].
    ///
    /// At each position the longest earlier, non-overlapping match inside the window
    /// is encoded as a back-reference. When none reaches the minimum length, the next
    /// up to 255 bytes are emitted as one literal run.
    pub fn compress(&self, input: &[u8]) -> Vec<u8> {
        let min_match = self.min_match_len.max(1);
        // Offsets are stored in two bytes. Searching further back than that would
        // produce distances that get truncated on write and corrupt the stream.
        let look_back = self.max_look_back.clamp(1, Self::MAX_OFFSET);

        let mut out = Vec::new();
        let mut pos = 0usize;
        while pos < input.len() {
            let (best_off, best_len) = Self::longest_match(input, pos, look_back, min_match);
            if best_len >= min_match {
                out.push(0x01);
                out.extend_from_slice(&(best_off as u16).to_le_bytes());
                out.push(best_len as u8);
                pos += best_len;
            } else {
                let run_len = (input.len() - pos).min(Self::MAX_RUN);
                out.push(0x00);
                out.push(run_len as u8);
                out.extend_from_slice(&input[pos..pos + run_len]);
                pos += run_len;
            }
        }
        out
    }

    /// Reverses [`CheckpointCompressor::compress`].
    ///
    /// # Errors
    ///
    /// Returns [`io::ErrorKind::UnexpectedEof`] when a token is cut short and
    /// [`io::ErrorKind::InvalidData`] for an unknown tag or a back-reference whose
    /// offset is zero or reaches before the start of the output.
    pub fn decompress(&self, input: &[u8]) -> io::Result<Vec<u8>> {
        let mut out: Vec<u8> = Vec::new();
        let mut i = 0usize;
        while i < input.len() {
            let tag = input[i];
            i += 1;
            match tag {
                0x00 => {
                    let Some(&run_byte) = input.get(i) else {
                        return Err(io::Error::new(
                            io::ErrorKind::UnexpectedEof,
                            "literal run truncated",
                        ));
                    };
                    let run_len = run_byte as usize;
                    i += 1;
                    let Some(literal) = input.get(i..i + run_len) else {
                        return Err(io::Error::new(
                            io::ErrorKind::UnexpectedEof,
                            "literal data truncated",
                        ));
                    };
                    out.extend_from_slice(literal);
                    i += run_len;
                }
                0x01 => {
                    let Some(fields) = input.get(i..i + 3) else {
                        return Err(io::Error::new(
                            io::ErrorKind::UnexpectedEof,
                            "back-ref truncated",
                        ));
                    };
                    let offset = u16::from_le_bytes([fields[0], fields[1]]) as usize;
                    let length = fields[2] as usize;
                    i += 3;
                    if offset == 0 || offset > out.len() {
                        return Err(io::Error::new(
                            io::ErrorKind::InvalidData,
                            format!("invalid back-ref offset {offset}"),
                        ));
                    }
                    // Copy byte by byte: the source range may run into bytes that this
                    // same copy has just produced.
                    let start = out.len() - offset;
                    out.reserve(length);
                    for k in 0..length {
                        let byte = out[start + k];
                        out.push(byte);
                    }
                }
                _ => {
                    return Err(io::Error::new(
                        io::ErrorKind::InvalidData,
                        format!("unknown tag 0x{tag:02X}"),
                    ));
                }
            }
        }
        Ok(out)
    }

    /// Ratio `compressed_len / original_len`; `1.0` when `original_len` is zero.
    pub fn compression_ratio(original_len: usize, compressed_len: usize) -> f64 {
        if original_len == 0 {
            return 1.0_f64;
        }
        compressed_len as f64 / original_len as f64
    }

    /// Finds the longest match for `input[pos..]` that starts inside the window and
    /// ends before `pos`. Returns `(distance, length)`, or `(0, 0)` when no match
    /// reaches `min_match`. Among equally long matches the earliest start wins.
    fn longest_match(
        input: &[u8],
        pos: usize,
        look_back: usize,
        min_match: usize,
    ) -> (usize, usize) {
        let remaining = input.len() - pos;
        let first = input[pos];
        let mut best = (0usize, 0usize);
        for start in pos.saturating_sub(look_back)..pos {
            // Cheap rejection before the slice comparison below.
            if input[start] != first {
                continue;
            }
            // A match may not run past the input, past `pos`, or past one token.
            let limit = remaining.min(pos - start).min(Self::MAX_RUN);
            let len = input[start..start + limit]
                .iter()
                .zip(&input[pos..pos + limit])
                .take_while(|(a, b)| a == b)
                .count();
            if len > best.1 && len >= min_match {
                best = (pos - start, len);
            }
        }
        best
    }
}
1076#[derive(Debug, Clone)]
1081pub struct CheckpointCatalog {
1082 pub base_dir: PathBuf,
1084 pub entries: Vec<(u64, PathBuf)>,
1086}
1087impl CheckpointCatalog {
1088 pub fn scan(base_dir: impl Into<PathBuf>) -> Self {
1093 let base_dir: PathBuf = base_dir.into();
1094 let mut entries: Vec<(u64, PathBuf)> = Vec::new();
1095 if let Ok(dir_entries) = fs::read_dir(&base_dir) {
1096 for entry in dir_entries.flatten() {
1097 let path = entry.path();
1098 if let Some(name) = path.file_name().and_then(|n| n.to_str())
1099 && name.starts_with("checkpoint_")
1100 && name.ends_with(".bin")
1101 {
1102 let step_str = &name[11..name.len() - 4];
1103 if let Ok(step) = step_str.parse::<u64>() {
1104 entries.push((step, path));
1105 }
1106 }
1107 }
1108 }
1109 entries.sort_by_key(|(s, _)| *s);
1110 Self { base_dir, entries }
1111 }
1112 pub fn len(&self) -> usize {
1114 self.entries.len()
1115 }
1116 pub fn is_empty(&self) -> bool {
1118 self.entries.is_empty()
1119 }
1120 pub fn steps(&self) -> Vec<u64> {
1122 self.entries.iter().map(|(s, _)| *s).collect()
1123 }
1124 pub fn path_for_step(&self, step: u64) -> Option<&PathBuf> {
1126 self.entries
1127 .binary_search_by_key(&step, |(s, _)| *s)
1128 .ok()
1129 .map(|idx| &self.entries[idx].1)
1130 }
1131 pub fn load_step(&self, step: u64) -> io::Result<Checkpoint> {
1136 let path = self.path_for_step(step).ok_or_else(|| {
1137 io::Error::new(
1138 io::ErrorKind::NotFound,
1139 format!("step {step} not in catalog"),
1140 )
1141 })?;
1142 let data = fs::read(path)?;
1143 Checkpoint::from_bytes(&data)
1144 }
1145 pub fn latest(&self) -> Option<&PathBuf> {
1147 self.entries.last().map(|(_, p)| p)
1148 }
1149 pub fn earliest(&self) -> Option<&PathBuf> {
1151 self.entries.first().map(|(_, p)| p)
1152 }
1153 pub fn add(&mut self, checkpoint: &Checkpoint) -> io::Result<()> {
1157 let path = self
1158 .base_dir
1159 .join(format!("checkpoint_{:010}.bin", checkpoint.step));
1160 let bytes = checkpoint.to_bytes();
1161 fs::write(&path, &bytes)?;
1162 let pos = self.entries.partition_point(|(s, _)| *s < checkpoint.step);
1163 self.entries.insert(pos, (checkpoint.step, path));
1164 Ok(())
1165 }
1166 pub fn remove_step(&mut self, step: u64) -> io::Result<()> {
1168 let pos = self
1169 .entries
1170 .binary_search_by_key(&step, |(s, _)| *s)
1171 .map_err(|_| {
1172 io::Error::new(
1173 io::ErrorKind::NotFound,
1174 format!("step {step} not in catalog"),
1175 )
1176 })?;
1177 let (_, path) = self.entries.remove(pos);
1178 if path.exists() {
1179 fs::remove_file(&path)?;
1180 }
1181 Ok(())
1182 }
1183}
1184#[derive(Debug, Clone)]
1186pub struct CheckpointReader {
1187 pub path: PathBuf,
1189}
1190impl CheckpointReader {
1191 pub fn new(path: impl Into<PathBuf>) -> Self {
1193 Self { path: path.into() }
1194 }
1195 pub fn read_metadata(&self) -> io::Result<CheckpointMetadata> {
1197 let data = fs::read(&self.path)?;
1198 let mut cursor = 0usize;
1199 let magic = read_u32(&data, &mut cursor)?;
1200 if magic != MAGIC {
1201 return Err(io::Error::new(
1202 io::ErrorKind::InvalidData,
1203 "bad magic number",
1204 ));
1205 }
1206 let _version = read_u32(&data, &mut cursor)?;
1207 let meta_len = read_u32(&data, &mut cursor)? as usize;
1208 if cursor + meta_len > data.len() {
1209 return Err(io::Error::new(
1210 io::ErrorKind::UnexpectedEof,
1211 "metadata block truncated",
1212 ));
1213 }
1214 CheckpointMetadata::from_bytes(&data[cursor..cursor + meta_len])
1215 }
1216 pub fn read_positions(&self) -> io::Result<Vec<[f64; 3]>> {
1218 self.read_vec3_block(TAG_POSITIONS)
1219 }
1220 pub fn read_velocities(&self) -> io::Result<Vec<[f64; 3]>> {
1222 self.read_vec3_block(TAG_VELOCITIES)
1223 }
1224 pub fn read_scalars(&self, name: &str) -> io::Result<Vec<f64>> {
1226 let data = fs::read(&self.path)?;
1227 let mut cursor = self.skip_header(&data)?;
1228 while cursor < data.len() {
1229 let tag = data[cursor];
1230 cursor += 1;
1231 match tag {
1232 TAG_SCALARS => {
1233 let stored_name = read_name(&data, &mut cursor)?;
1234 let count = read_u64(&data, &mut cursor)? as usize;
1235 if stored_name == name {
1236 let mut out = Vec::with_capacity(count);
1237 for _ in 0..count {
1238 out.push(read_f64(&data, &mut cursor)?);
1239 }
1240 return Ok(out);
1241 } else {
1242 cursor += count * 8;
1243 }
1244 }
1245 TAG_POSITIONS | TAG_VELOCITIES => {
1246 let count = read_u64(&data, &mut cursor)? as usize;
1247 cursor += count * 24;
1248 }
1249 TAG_INTEGERS => {
1250 let _n = read_name(&data, &mut cursor)?;
1251 let count = read_u64(&data, &mut cursor)? as usize;
1252 cursor += count * 4;
1253 }
1254 TAG_FOOTER => break,
1255 _ => {
1256 return Err(io::Error::new(
1257 io::ErrorKind::InvalidData,
1258 format!("unknown tag 0x{tag:02X}"),
1259 ));
1260 }
1261 }
1262 }
1263 Err(io::Error::new(
1264 io::ErrorKind::NotFound,
1265 format!("scalar array '{name}' not found"),
1266 ))
1267 }
1268 fn skip_header(&self, data: &[u8]) -> io::Result<usize> {
1269 let mut cursor = 0usize;
1270 let _magic = read_u32(data, &mut cursor)?;
1271 let _version = read_u32(data, &mut cursor)?;
1272 let meta_len = read_u32(data, &mut cursor)? as usize;
1273 cursor += meta_len;
1274 Ok(cursor)
1275 }
1276 fn read_vec3_block(&self, target_tag: u8) -> io::Result<Vec<[f64; 3]>> {
1277 let data = fs::read(&self.path)?;
1278 let mut cursor = self.skip_header(&data)?;
1279 while cursor < data.len() {
1280 let tag = data[cursor];
1281 cursor += 1;
1282 match tag {
1283 t if t == target_tag => {
1284 let count = read_u64(&data, &mut cursor)? as usize;
1285 let mut out = Vec::with_capacity(count);
1286 for _ in 0..count {
1287 let x = read_f64(&data, &mut cursor)?;
1288 let y = read_f64(&data, &mut cursor)?;
1289 let z = read_f64(&data, &mut cursor)?;
1290 out.push([x, y, z]);
1291 }
1292 return Ok(out);
1293 }
1294 TAG_POSITIONS | TAG_VELOCITIES => {
1295 let count = read_u64(&data, &mut cursor)? as usize;
1296 cursor += count * 24;
1297 }
1298 TAG_SCALARS | TAG_INTEGERS => {
1299 let _n = read_name(&data, &mut cursor)?;
1300 let count = read_u64(&data, &mut cursor)? as usize;
1301 let elem_size = if tag == TAG_SCALARS { 8 } else { 4 };
1302 cursor += count * elem_size;
1303 }
1304 TAG_FOOTER => break,
1305 _ => {
1306 return Err(io::Error::new(
1307 io::ErrorKind::InvalidData,
1308 format!("unknown tag 0x{tag:02X}"),
1309 ));
1310 }
1311 }
1312 }
1313 Ok(vec![])
1314 }
1315}
1316#[derive(Debug, Clone)]
1318pub struct CheckpointFileReader {
1319 pub path: PathBuf,
1321}
1322impl CheckpointFileReader {
1323 pub fn new(path: impl Into<PathBuf>) -> Self {
1325 Self { path: path.into() }
1326 }
1327 pub fn read_and_validate(&self) -> io::Result<Checkpoint> {
1332 let data = fs::read(&self.path)?;
1333 let ckpt = Checkpoint::from_bytes(&data)?;
1334 if !ckpt.verify() {
1335 return Err(io::Error::new(
1336 io::ErrorKind::InvalidData,
1337 "checkpoint checksum mismatch",
1338 ));
1339 }
1340 Ok(ckpt)
1341 }
1342}