Skip to main content

oxiphysics_io/
restart_io.rs

1// Copyright 2026 COOLJAPAN OU (Team KitaSan)
2// SPDX-License-Identifier: Apache-2.0
3
4//! Simulation restart and checkpoint I/O.
5//!
6//! Provides writers and readers for simulation restart files in multiple
7//! formats (binary, ASCII, JSON, simplified HDF5-like, MessagePack-like),
8//! plus a rolling [`CheckpointManager`] and a [`RestartValidator`] for
9//! checksum verification.
10
11#![allow(dead_code)]
12#![allow(clippy::too_many_arguments)]
13
14// ── Magic bytes ───────────────────────────────────────────────────────────────
15
/// Four-byte signature (ASCII `OXRS`) placed at the very start of a binary restart blob.
const BINARY_MAGIC: [u8; 4] = *b"OXRS";

/// Layout version number written into the binary header right after the magic.
const FORMAT_VERSION: u32 = 1;
21
22// ── RestartFormat ─────────────────────────────────────────────────────────────
23
/// Serialization format for restart / checkpoint files.
///
/// [`RestartWriter::write`] dispatches on this value to pick an encoder, and
/// `RestartReader::detect_format` returns it after inspecting the leading
/// bytes of a file.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RestartFormat {
    /// Custom little-endian binary format (fastest, smallest).
    Binary,
    /// Human-readable ASCII key-value format.
    Ascii,
    /// Simplified HDF5-like layout (group/dataset tags in binary).
    ///
    /// The writer in this file emits the `Binary` layout with the magic
    /// bytes replaced by `OXH5`.
    Hdf5Like,
    /// JSON text format (portable, verbose).
    Json,
    /// Simplified MessagePack-like binary encoding.
    ///
    /// The writer in this file emits the `Binary` layout with the magic
    /// bytes replaced by `MPK1`.
    MessagePack,
}
38
39// ── RestartMetadata ───────────────────────────────────────────────────────────
40
/// Descriptive header that travels with every restart snapshot.
#[derive(Debug, Clone, PartialEq)]
pub struct RestartMetadata {
    /// File-format version string (e.g. `"1.0"`).
    pub version: String,
    /// UNIX timestamp (seconds since epoch) at which the file was written.
    pub timestamp: u64,
    /// Simulation step number of the snapshot.
    pub step: u64,
    /// Simulation time (in simulation units).
    pub time: f64,
    /// Name of the crate that produced the file.
    pub crate_name: String,
    /// Free-form, human-readable description.
    pub description: String,
}

impl RestartMetadata {
    /// Assemble a `RestartMetadata` from its individual fields.
    ///
    /// The three textual arguments accept anything convertible into a
    /// `String`, so both `&str` and owned strings can be passed.
    pub fn new(
        version: impl Into<String>,
        timestamp: u64,
        step: u64,
        time: f64,
        crate_name: impl Into<String>,
        description: impl Into<String>,
    ) -> Self {
        let version = version.into();
        let crate_name = crate_name.into();
        let description = description.into();
        Self {
            version,
            timestamp,
            step,
            time,
            crate_name,
            description,
        }
    }

    /// Fixed metadata used by tests: version `"1.0"`, zero timestamp, step
    /// and time, crate `"oxiphysics"`, description `"test checkpoint"`.
    pub fn default_test() -> Self {
        Self {
            version: String::from("1.0"),
            timestamp: 0,
            step: 0,
            time: 0.0,
            crate_name: String::from("oxiphysics"),
            description: String::from("test checkpoint"),
        }
    }
}
83
84// ── RestartData ───────────────────────────────────────────────────────────────
85
86/// Full simulation state stored in a restart file.
87#[derive(Debug, Clone, PartialEq)]
88pub struct RestartData {
89    /// File metadata.
90    pub metadata: RestartMetadata,
91    /// Particle positions \[x, y, z\] (simulation units).
92    pub positions: Vec<[f64; 3]>,
93    /// Particle velocities \[vx, vy, vz\].
94    pub velocities: Vec<[f64; 3]>,
95    /// Per-particle forces \[fx, fy, fz\].
96    pub forces: Vec<[f64; 3]>,
97    /// Per-particle masses.
98    pub masses: Vec<f64>,
99    /// Per-particle type indices.
100    pub types: Vec<u32>,
101    /// Simulation box matrix (row-major: rows are box vectors a, b, c).
102    pub box_matrix: [[f64; 3]; 3],
103    /// Named extra scalar arrays `(name, values)`.
104    pub extra_scalars: Vec<(String, Vec<f64>)>,
105    /// Named extra vector arrays `(name, vectors)`.
106    pub extra_vectors: Vec<(String, Vec<[f64; 3]>)>,
107}
108
109impl RestartData {
110    /// Number of particles in this snapshot.
111    pub fn n_particles(&self) -> usize {
112        self.positions.len()
113    }
114
115    /// Construct a minimal empty `RestartData`.
116    pub fn empty(metadata: RestartMetadata) -> Self {
117        Self {
118            metadata,
119            positions: Vec::new(),
120            velocities: Vec::new(),
121            forces: Vec::new(),
122            masses: Vec::new(),
123            types: Vec::new(),
124            box_matrix: [[0.0; 3]; 3],
125            extra_scalars: Vec::new(),
126            extra_vectors: Vec::new(),
127        }
128    }
129
130    /// Build a simple single-particle `RestartData` for tests.
131    pub fn single_particle_test() -> Self {
132        let meta = RestartMetadata::default_test();
133        let mut d = Self::empty(meta);
134        d.positions = vec![[1.0, 2.0, 3.0]];
135        d.velocities = vec![[0.1, 0.2, 0.3]];
136        d.forces = vec![[0.0, -9.81, 0.0]];
137        d.masses = vec![1.0];
138        d.types = vec![0];
139        d.box_matrix = [[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]];
140        d
141    }
142}
143
144// ── Helper: encode / decode primitives (little-endian) ────────────────────────
145
/// Append `v` to `buf` as four little-endian bytes.
fn encode_u32(buf: &mut Vec<u8>, v: u32) {
    let le_bytes: [u8; 4] = v.to_le_bytes();
    buf.extend(le_bytes);
}
149
/// Append `v` to `buf` as eight little-endian bytes.
fn encode_u64(buf: &mut Vec<u8>, v: u64) {
    let le_bytes: [u8; 8] = v.to_le_bytes();
    buf.extend(le_bytes);
}
153
/// Append the IEEE-754 bit pattern of `v` to `buf` as eight little-endian bytes.
fn encode_f64(buf: &mut Vec<u8>, v: f64) {
    // `f64::to_le_bytes` serializes the same bit pattern as `to_bits().to_le_bytes()`.
    buf.extend_from_slice(&v.to_le_bytes());
}
157
158fn encode_str(buf: &mut Vec<u8>, s: &str) {
159    let bytes = s.as_bytes();
160    encode_u32(buf, bytes.len() as u32);
161    buf.extend_from_slice(bytes);
162}
163
164fn encode_vec3(buf: &mut Vec<u8>, v: &[f64; 3]) {
165    encode_f64(buf, v[0]);
166    encode_f64(buf, v[1]);
167    encode_f64(buf, v[2]);
168}
169
/// Read a little-endian `u32` from `data` at `*offset` and advance the
/// offset by four bytes.
///
/// Returns an error, leaving `*offset` untouched, when fewer than four
/// bytes remain.
fn decode_u32(data: &[u8], offset: &mut usize) -> Result<u32, String> {
    let start = *offset;
    let chunk = match data.get(start..start + 4) {
        Some(chunk) => chunk,
        None => return Err(format!("unexpected EOF at offset {}", start)),
    };
    let mut raw = [0u8; 4];
    raw.copy_from_slice(chunk);
    *offset = start + 4;
    Ok(u32::from_le_bytes(raw))
}
182
/// Read a little-endian `u64` from `data` at `*offset` and advance the
/// offset by eight bytes.
///
/// Returns an error, leaving `*offset` untouched, when fewer than eight
/// bytes remain.
fn decode_u64(data: &[u8], offset: &mut usize) -> Result<u64, String> {
    let start = *offset;
    let chunk = match data.get(start..start + 8) {
        Some(chunk) => chunk,
        None => return Err(format!("unexpected EOF at offset {}", start)),
    };
    let mut raw = [0u8; 8];
    raw.copy_from_slice(chunk);
    *offset = start + 8;
    Ok(u64::from_le_bytes(raw))
}
195
/// Read an `f64` stored as eight little-endian bytes of its IEEE-754 bit
/// pattern from `data` at `*offset`, advancing the offset by eight bytes.
///
/// Returns an error, leaving `*offset` untouched, when fewer than eight
/// bytes remain.
fn decode_f64(data: &[u8], offset: &mut usize) -> Result<f64, String> {
    let start = *offset;
    let chunk = match data.get(start..start + 8) {
        Some(chunk) => chunk,
        None => return Err(format!("unexpected EOF at offset {}", start)),
    };
    let mut raw = [0u8; 8];
    raw.copy_from_slice(chunk);
    *offset = start + 8;
    Ok(f64::from_le_bytes(raw))
}
208
209fn decode_str(data: &[u8], offset: &mut usize) -> Result<String, String> {
210    let len = decode_u32(data, offset)? as usize;
211    if *offset + len > data.len() {
212        return Err(format!("string extends past EOF at offset {}", *offset));
213    }
214    let s = std::str::from_utf8(&data[*offset..*offset + len])
215        .map_err(|e| format!("UTF-8 error: {e}"))?
216        .to_string();
217    *offset += len;
218    Ok(s)
219}
220
221fn decode_vec3(data: &[u8], offset: &mut usize) -> Result<[f64; 3], String> {
222    let x = decode_f64(data, offset)?;
223    let y = decode_f64(data, offset)?;
224    let z = decode_f64(data, offset)?;
225    Ok([x, y, z])
226}
227
228// ── RestartWriter ─────────────────────────────────────────────────────────────
229
/// Writes simulation restart data to files or in-memory buffers.
///
/// [`RestartWriter::write`] serializes a snapshot to `path` in `format`;
/// the associated `write_*` functions return the encoded bytes or text
/// without touching the file system.
#[derive(Debug, Clone)]
pub struct RestartWriter {
    /// Output file path.
    pub path: String,
    /// Serialization format to use.
    pub format: RestartFormat,
}
238
239impl RestartWriter {
240    /// Create a new `RestartWriter` targeting `path` with the given `format`.
241    pub fn new(path: &str, format: RestartFormat) -> Self {
242        Self {
243            path: path.to_string(),
244            format,
245        }
246    }
247
248    /// Serialize `data` according to `self.format` and write to `self.path`.
249    ///
250    /// Returns an error string on I/O or encoding failure.
251    pub fn write(&self, data: &RestartData) -> Result<(), String> {
252        let bytes: Vec<u8> = match &self.format {
253            RestartFormat::Binary => Self::write_binary(data),
254            RestartFormat::Ascii => Self::write_ascii(data).into_bytes(),
255            RestartFormat::Json => Self::write_json(data).into_bytes(),
256            RestartFormat::Hdf5Like => Self::write_hdf5like(data),
257            RestartFormat::MessagePack => Self::write_msgpack(data),
258        };
259        std::fs::write(&self.path, &bytes)
260            .map_err(|e| format!("failed to write '{}': {e}", self.path))
261    }
262
263    /// Encode `data` as a custom little-endian binary blob.
264    ///
265    /// Layout: `[MAGIC 4B][VERSION 4B][METADATA][N_PARTICLES 8B]
266    ///          [POSITIONS…][VELOCITIES…][FORCES…][MASSES…][TYPES…]
267    ///          [BOX_MATRIX 72B][N_EXTRA_SCALARS 4B][…][N_EXTRA_VECS 4B][…]`
268    pub fn write_binary(data: &RestartData) -> Vec<u8> {
269        let mut buf = Vec::new();
270        buf.extend_from_slice(&BINARY_MAGIC);
271        encode_u32(&mut buf, FORMAT_VERSION);
272        // metadata
273        encode_str(&mut buf, &data.metadata.version);
274        encode_u64(&mut buf, data.metadata.timestamp);
275        encode_u64(&mut buf, data.metadata.step);
276        encode_f64(&mut buf, data.metadata.time);
277        encode_str(&mut buf, &data.metadata.crate_name);
278        encode_str(&mut buf, &data.metadata.description);
279        // particle count
280        let n = data.n_particles() as u64;
281        encode_u64(&mut buf, n);
282        for p in &data.positions {
283            encode_vec3(&mut buf, p);
284        }
285        for v in &data.velocities {
286            encode_vec3(&mut buf, v);
287        }
288        for f in &data.forces {
289            encode_vec3(&mut buf, f);
290        }
291        for m in &data.masses {
292            encode_f64(&mut buf, *m);
293        }
294        for t in &data.types {
295            encode_u32(&mut buf, *t);
296        }
297        // box matrix (9 f64)
298        for row in &data.box_matrix {
299            for &c in row {
300                encode_f64(&mut buf, c);
301            }
302        }
303        // extra scalars
304        encode_u32(&mut buf, data.extra_scalars.len() as u32);
305        for (name, vals) in &data.extra_scalars {
306            encode_str(&mut buf, name);
307            encode_u64(&mut buf, vals.len() as u64);
308            for &v in vals {
309                encode_f64(&mut buf, v);
310            }
311        }
312        // extra vectors
313        encode_u32(&mut buf, data.extra_vectors.len() as u32);
314        for (name, vecs) in &data.extra_vectors {
315            encode_str(&mut buf, name);
316            encode_u64(&mut buf, vecs.len() as u64);
317            for v in vecs {
318                encode_vec3(&mut buf, v);
319            }
320        }
321        buf
322    }
323
324    /// Encode `data` as a human-readable ASCII text block.
325    pub fn write_ascii(data: &RestartData) -> String {
326        let mut s = String::new();
327        s.push_str("# OxiPhysics restart file\n");
328        s.push_str(&format!("VERSION {}\n", data.metadata.version));
329        s.push_str(&format!("TIMESTAMP {}\n", data.metadata.timestamp));
330        s.push_str(&format!("STEP {}\n", data.metadata.step));
331        s.push_str(&format!("TIME {:.6}\n", data.metadata.time));
332        s.push_str(&format!("CRATE {}\n", data.metadata.crate_name));
333        s.push_str(&format!("DESC {}\n", data.metadata.description));
334        let n = data.n_particles();
335        s.push_str(&format!("N_PARTICLES {n}\n"));
336        s.push_str("BEGIN_POSITIONS\n");
337        for p in &data.positions {
338            s.push_str(&format!("{:.6} {:.6} {:.6}\n", p[0], p[1], p[2]));
339        }
340        s.push_str("END_POSITIONS\n");
341        s.push_str("BEGIN_VELOCITIES\n");
342        for v in &data.velocities {
343            s.push_str(&format!("{:.6} {:.6} {:.6}\n", v[0], v[1], v[2]));
344        }
345        s.push_str("END_VELOCITIES\n");
346        s.push_str("BEGIN_FORCES\n");
347        for f in &data.forces {
348            s.push_str(&format!("{:.6} {:.6} {:.6}\n", f[0], f[1], f[2]));
349        }
350        s.push_str("END_FORCES\n");
351        s.push_str("BEGIN_MASSES\n");
352        for m in &data.masses {
353            s.push_str(&format!("{:.6}\n", m));
354        }
355        s.push_str("END_MASSES\n");
356        s.push_str("BEGIN_TYPES\n");
357        for t in &data.types {
358            s.push_str(&format!("{t}\n"));
359        }
360        s.push_str("END_TYPES\n");
361        s.push_str("BEGIN_BOX\n");
362        for row in &data.box_matrix {
363            s.push_str(&format!("{:.6} {:.6} {:.6}\n", row[0], row[1], row[2]));
364        }
365        s.push_str("END_BOX\n");
366        s
367    }
368
369    /// Encode `data` as a JSON string (hand-rolled, no external dependency).
370    pub fn write_json(data: &RestartData) -> String {
371        let mut s = String::new();
372        s.push_str("{\n");
373        s.push_str(&format!(
374            "  \"version\": \"{}\",\n  \"timestamp\": {},\n  \"step\": {},\n  \"time\": {:.6},\n  \"crate\": \"{}\",\n  \"description\": \"{}\",\n",
375            data.metadata.version, data.metadata.timestamp, data.metadata.step,
376            data.metadata.time, data.metadata.crate_name, data.metadata.description
377        ));
378        s.push_str(&format!("  \"n_particles\": {},\n", data.n_particles()));
379        // positions
380        s.push_str("  \"positions\": [");
381        for (i, p) in data.positions.iter().enumerate() {
382            if i > 0 {
383                s.push(',');
384            }
385            s.push_str(&format!("[{:.6},{:.6},{:.6}]", p[0], p[1], p[2]));
386        }
387        s.push_str("],\n");
388        // velocities
389        s.push_str("  \"velocities\": [");
390        for (i, v) in data.velocities.iter().enumerate() {
391            if i > 0 {
392                s.push(',');
393            }
394            s.push_str(&format!("[{:.6},{:.6},{:.6}]", v[0], v[1], v[2]));
395        }
396        s.push_str("],\n");
397        // masses
398        s.push_str("  \"masses\": [");
399        for (i, m) in data.masses.iter().enumerate() {
400            if i > 0 {
401                s.push(',');
402            }
403            s.push_str(&format!("{:.6}", m));
404        }
405        s.push_str("],\n");
406        // types
407        s.push_str("  \"types\": [");
408        for (i, t) in data.types.iter().enumerate() {
409            if i > 0 {
410                s.push(',');
411            }
412            s.push_str(&format!("{t}"));
413        }
414        s.push_str("]\n}\n");
415        s
416    }
417
418    /// Encode `data` using a simplified HDF5-like tagged binary layout.
419    pub fn write_hdf5like(data: &RestartData) -> Vec<u8> {
420        // Reuse the binary writer with an alternate magic to mark it as hdf5-like.
421        let mut buf = Self::write_binary(data);
422        // Overwrite magic bytes to distinguish from plain Binary.
423        buf[0] = 0x4F; // 'O'
424        buf[1] = 0x58; // 'X'
425        buf[2] = 0x48; // 'H'
426        buf[3] = 0x35; // '5'
427        buf
428    }
429
430    /// Encode `data` using a simplified MessagePack-like binary encoding.
431    pub fn write_msgpack(data: &RestartData) -> Vec<u8> {
432        // Reuse the binary writer with an alternate magic.
433        let mut buf = Self::write_binary(data);
434        buf[0] = 0x4D; // 'M'
435        buf[1] = 0x50; // 'P'
436        buf[2] = 0x4B; // 'K'
437        buf[3] = 0x31; // '1'
438        buf
439    }
440}
441
442// ── RestartReader ─────────────────────────────────────────────────────────────
443
/// Reads simulation restart data from files or in-memory buffers.
///
/// `read` loads and parses the file at `path`; `detect_format`,
/// `read_binary` and `read_ascii` are associated functions that work on
/// data already held in memory.
#[derive(Debug, Clone)]
pub struct RestartReader {
    /// Input file path.
    pub path: String,
}
450
451impl RestartReader {
452    /// Create a new `RestartReader` for the given file path.
453    pub fn new(path: &str) -> Self {
454        Self {
455            path: path.to_string(),
456        }
457    }
458
459    /// Read and parse a restart file from `self.path`.
460    pub fn read(&self) -> Result<RestartData, String> {
461        let bytes = std::fs::read(&self.path)
462            .map_err(|e| format!("failed to read '{}': {e}", self.path))?;
463        let fmt = Self::detect_format(&bytes);
464        match fmt {
465            RestartFormat::Ascii => {
466                let text = std::str::from_utf8(&bytes).map_err(|e| format!("UTF-8 error: {e}"))?;
467                Self::read_ascii(text)
468            }
469            _ => Self::read_binary(&bytes),
470        }
471    }
472
473    /// Detect the serialization format by inspecting magic bytes.
474    pub fn detect_format(bytes: &[u8]) -> RestartFormat {
475        if bytes.len() < 4 {
476            return RestartFormat::Ascii;
477        }
478        match &bytes[0..4] {
479            b"OXRS" => RestartFormat::Binary,
480            b"OXHS" | [0x4F, 0x58, 0x48, 0x35] => RestartFormat::Hdf5Like,
481            [0x4D, 0x50, 0x4B, 0x31] => RestartFormat::MessagePack,
482            b"# Ox" | b"VERS" => RestartFormat::Ascii,
483            _ if bytes.starts_with(b"{") => RestartFormat::Json,
484            _ => RestartFormat::Ascii,
485        }
486    }
487
488    /// Parse a `RestartData` from a raw binary blob (any of the binary formats).
489    pub fn read_binary(bytes: &[u8]) -> Result<RestartData, String> {
490        let mut off = 0usize;
491        // magic (4) + version (4)
492        if bytes.len() < 8 {
493            return Err("binary too short".into());
494        }
495        off += 4; // skip magic
496        let _file_version = decode_u32(bytes, &mut off)?;
497        // metadata
498        let version = decode_str(bytes, &mut off)?;
499        let timestamp = decode_u64(bytes, &mut off)?;
500        let step = decode_u64(bytes, &mut off)?;
501        let time = decode_f64(bytes, &mut off)?;
502        let crate_name = decode_str(bytes, &mut off)?;
503        let description = decode_str(bytes, &mut off)?;
504        let metadata = RestartMetadata {
505            version,
506            timestamp,
507            step,
508            time,
509            crate_name,
510            description,
511        };
512        // particle count
513        let n = decode_u64(bytes, &mut off)? as usize;
514        let mut positions = Vec::with_capacity(n);
515        for _ in 0..n {
516            positions.push(decode_vec3(bytes, &mut off)?);
517        }
518        let mut velocities = Vec::with_capacity(n);
519        for _ in 0..n {
520            velocities.push(decode_vec3(bytes, &mut off)?);
521        }
522        let mut forces = Vec::with_capacity(n);
523        for _ in 0..n {
524            forces.push(decode_vec3(bytes, &mut off)?);
525        }
526        let mut masses = Vec::with_capacity(n);
527        for _ in 0..n {
528            masses.push(decode_f64(bytes, &mut off)?);
529        }
530        let mut types = Vec::with_capacity(n);
531        for _ in 0..n {
532            types.push(decode_u32(bytes, &mut off)?);
533        }
534        // box matrix
535        let mut box_matrix = [[0.0f64; 3]; 3];
536        for row in &mut box_matrix {
537            for c in row.iter_mut() {
538                *c = decode_f64(bytes, &mut off)?;
539            }
540        }
541        // extra scalars
542        let n_es = decode_u32(bytes, &mut off)? as usize;
543        let mut extra_scalars = Vec::with_capacity(n_es);
544        for _ in 0..n_es {
545            let name = decode_str(bytes, &mut off)?;
546            let count = decode_u64(bytes, &mut off)? as usize;
547            let mut vals = Vec::with_capacity(count);
548            for _ in 0..count {
549                vals.push(decode_f64(bytes, &mut off)?);
550            }
551            extra_scalars.push((name, vals));
552        }
553        // extra vectors
554        let n_ev = decode_u32(bytes, &mut off)? as usize;
555        let mut extra_vectors = Vec::with_capacity(n_ev);
556        for _ in 0..n_ev {
557            let name = decode_str(bytes, &mut off)?;
558            let count = decode_u64(bytes, &mut off)? as usize;
559            let mut vecs = Vec::with_capacity(count);
560            for _ in 0..count {
561                vecs.push(decode_vec3(bytes, &mut off)?);
562            }
563            extra_vectors.push((name, vecs));
564        }
565        Ok(RestartData {
566            metadata,
567            positions,
568            velocities,
569            forces,
570            masses,
571            types,
572            box_matrix,
573            extra_scalars,
574            extra_vectors,
575        })
576    }
577
578    /// Parse a `RestartData` from an ASCII restart text.
579    pub fn read_ascii(text: &str) -> Result<RestartData, String> {
580        let mut version = String::new();
581        let mut timestamp = 0u64;
582        let mut step = 0u64;
583        let mut time = 0.0f64;
584        let mut crate_name = String::new();
585        let mut description = String::new();
586        let mut positions: Vec<[f64; 3]> = Vec::new();
587        let mut velocities: Vec<[f64; 3]> = Vec::new();
588        let mut forces: Vec<[f64; 3]> = Vec::new();
589        let mut masses: Vec<f64> = Vec::new();
590        let mut types: Vec<u32> = Vec::new();
591        let mut box_matrix = [[0.0f64; 3]; 3];
592
593        #[derive(PartialEq)]
594        enum Section {
595            None,
596            Positions,
597            Velocities,
598            Forces,
599            Masses,
600            Types,
601            Box,
602        }
603        let mut section = Section::None;
604        let mut box_row = 0usize;
605
606        for line in text.lines() {
607            let line = line.trim();
608            if line.is_empty() || line.starts_with('#') {
609                continue;
610            }
611            match line {
612                "BEGIN_POSITIONS" => {
613                    section = Section::Positions;
614                    continue;
615                }
616                "END_POSITIONS" => {
617                    section = Section::None;
618                    continue;
619                }
620                "BEGIN_VELOCITIES" => {
621                    section = Section::Velocities;
622                    continue;
623                }
624                "END_VELOCITIES" => {
625                    section = Section::None;
626                    continue;
627                }
628                "BEGIN_FORCES" => {
629                    section = Section::Forces;
630                    continue;
631                }
632                "END_FORCES" => {
633                    section = Section::None;
634                    continue;
635                }
636                "BEGIN_MASSES" => {
637                    section = Section::Masses;
638                    continue;
639                }
640                "END_MASSES" => {
641                    section = Section::None;
642                    continue;
643                }
644                "BEGIN_TYPES" => {
645                    section = Section::Types;
646                    continue;
647                }
648                "END_TYPES" => {
649                    section = Section::None;
650                    continue;
651                }
652                "BEGIN_BOX" => {
653                    section = Section::Box;
654                    box_row = 0;
655                    continue;
656                }
657                "END_BOX" => {
658                    section = Section::None;
659                    continue;
660                }
661                _ => {}
662            }
663            match section {
664                Section::Positions | Section::Velocities | Section::Forces => {
665                    let nums: Vec<f64> = line
666                        .split_whitespace()
667                        .map(|s| s.parse::<f64>().unwrap_or(0.0))
668                        .collect();
669                    if nums.len() >= 3 {
670                        let arr = [nums[0], nums[1], nums[2]];
671                        match section {
672                            Section::Positions => positions.push(arr),
673                            Section::Velocities => velocities.push(arr),
674                            Section::Forces => forces.push(arr),
675                            _ => {}
676                        }
677                    }
678                }
679                Section::Masses => {
680                    if let Ok(m) = line.parse::<f64>() {
681                        masses.push(m);
682                    }
683                }
684                Section::Types => {
685                    if let Ok(t) = line.parse::<u32>() {
686                        types.push(t);
687                    }
688                }
689                Section::Box => {
690                    if box_row < 3 {
691                        let nums: Vec<f64> = line
692                            .split_whitespace()
693                            .map(|s| s.parse::<f64>().unwrap_or(0.0))
694                            .collect();
695                        if nums.len() >= 3 {
696                            box_matrix[box_row] = [nums[0], nums[1], nums[2]];
697                            box_row += 1;
698                        }
699                    }
700                }
701                Section::None => {
702                    // key-value header lines
703                    if let Some(rest) = line.strip_prefix("VERSION ") {
704                        version = rest.trim().to_string();
705                    } else if let Some(rest) = line.strip_prefix("TIMESTAMP ") {
706                        timestamp = rest.trim().parse().unwrap_or(0);
707                    } else if let Some(rest) = line.strip_prefix("STEP ") {
708                        step = rest.trim().parse().unwrap_or(0);
709                    } else if let Some(rest) = line.strip_prefix("TIME ") {
710                        time = rest.trim().parse().unwrap_or(0.0);
711                    } else if let Some(rest) = line.strip_prefix("CRATE ") {
712                        crate_name = rest.trim().to_string();
713                    } else if let Some(rest) = line.strip_prefix("DESC ") {
714                        description = rest.trim().to_string();
715                    }
716                }
717            }
718        }
719        let metadata = RestartMetadata {
720            version,
721            timestamp,
722            step,
723            time,
724            crate_name,
725            description,
726        };
727        Ok(RestartData {
728            metadata,
729            positions,
730            velocities,
731            forces,
732            masses,
733            types,
734            box_matrix,
735            extra_scalars: Vec::new(),
736            extra_vectors: Vec::new(),
737        })
738    }
739}
740
741// ── CheckpointManager ─────────────────────────────────────────────────────────
742
743/// Manages a rolling set of checkpoint files in a directory.
744#[derive(Debug, Clone)]
745pub struct CheckpointManager {
746    /// Base directory where checkpoint files are stored.
747    pub base_dir: String,
748    /// Maximum number of checkpoints to retain.
749    pub max_checkpoints: usize,
750    /// In-memory list of `(step, file_path)` known checkpoints.
751    checkpoints: Vec<(u64, String)>,
752}
753
754impl CheckpointManager {
755    /// Create a new manager rooted at `base_dir`, keeping at most `max_checkpoints` files.
756    pub fn new(base_dir: &str, max_checkpoints: usize) -> Self {
757        Self {
758            base_dir: base_dir.to_string(),
759            max_checkpoints: max_checkpoints.max(1),
760            checkpoints: Vec::new(),
761        }
762    }
763
764    /// Serialize `data` and write it as `checkpoint_`step`.bin` under `base_dir`.
765    ///
766    /// Returns the full path of the newly written file.
767    pub fn save_checkpoint(&mut self, data: &RestartData, step: u64) -> String {
768        let filename = format!("{}/checkpoint_{step:010}.bin", self.base_dir);
769        let writer = RestartWriter::new(&filename, RestartFormat::Binary);
770        let _ = std::fs::create_dir_all(&self.base_dir);
771        let _ = writer.write(data);
772        self.checkpoints.push((step, filename.clone()));
773        self.prune_old_checkpoints();
774        filename
775    }
776
777    /// Load and return the most recent checkpoint, or `None` if none exist.
778    pub fn load_latest(&self) -> Option<RestartData> {
779        let (_, path) = self.checkpoints.last()?;
780        let reader = RestartReader::new(path);
781        reader.read().ok()
782    }
783
784    /// Return all known checkpoints as `(step, path)` pairs, oldest first.
785    pub fn list_checkpoints(&self) -> Vec<(u64, String)> {
786        self.checkpoints.clone()
787    }
788
789    /// Remove the oldest checkpoints so that at most `max_checkpoints` remain.
790    pub fn prune_old_checkpoints(&mut self) {
791        while self.checkpoints.len() > self.max_checkpoints {
792            let (_, path) = self.checkpoints.remove(0);
793            let _ = std::fs::remove_file(&path);
794        }
795    }
796}
797
798// ── IncrementalRestart ────────────────────────────────────────────────────────
799
800/// Tracks which particle indices have changed since the last checkpoint.
801///
802/// Only changed particles are serialized, reducing I/O cost for large systems
803/// with sparse updates.
804#[derive(Debug, Clone, Default)]
805pub struct IncrementalRestart {
806    /// Indices of particles that have been marked as modified.
807    pub changed: Vec<usize>,
808}
809
810impl IncrementalRestart {
811    /// Create an empty tracker.
812    pub fn new() -> Self {
813        Self::default()
814    }
815
816    /// Mark particle `idx` as changed.
817    pub fn mark_changed(&mut self, idx: usize) {
818        if !self.changed.contains(&idx) {
819            self.changed.push(idx);
820        }
821    }
822
823    /// Clear the changed set (call after a checkpoint is written).
824    pub fn reset(&mut self) {
825        self.changed.clear();
826    }
827
828    /// Extract only the changed particles from `full` into a new `RestartData`.
829    ///
830    /// Metadata is copied from `full`; all other arrays are filtered to `changed` indices.
831    pub fn extract_delta(&self, full: &RestartData) -> RestartData {
832        let indices = &self.changed;
833        let positions = indices
834            .iter()
835            .filter_map(|&i| full.positions.get(i).copied())
836            .collect();
837        let velocities = indices
838            .iter()
839            .filter_map(|&i| full.velocities.get(i).copied())
840            .collect();
841        let forces = indices
842            .iter()
843            .filter_map(|&i| full.forces.get(i).copied())
844            .collect();
845        let masses = indices
846            .iter()
847            .filter_map(|&i| full.masses.get(i).copied())
848            .collect();
849        let types = indices
850            .iter()
851            .filter_map(|&i| full.types.get(i).copied())
852            .collect();
853        RestartData {
854            metadata: full.metadata.clone(),
855            positions,
856            velocities,
857            forces,
858            masses,
859            types,
860            box_matrix: full.box_matrix,
861            extra_scalars: Vec::new(),
862            extra_vectors: Vec::new(),
863        }
864    }
865}
866
867// ── RestartValidator ──────────────────────────────────────────────────────────
868
/// Stateless helper that computes and checks simple checksums over the raw
/// bytes of a restart blob.
#[derive(Debug, Clone, Default)]
pub struct RestartValidator;

impl RestartValidator {
    /// Build a validator (the type carries no state).
    pub fn new() -> Self {
        RestartValidator
    }

    /// Additive checksum: the sum of every byte of `bytes`, widened to `u64`.
    pub fn checksum_sum(bytes: &[u8]) -> u64 {
        let mut total: u64 = 0;
        for &byte in bytes {
            total += u64::from(byte);
        }
        total
    }

    /// XOR checksum: every byte of `bytes` XOR-ed together (0 for empty input).
    pub fn checksum_xor(bytes: &[u8]) -> u8 {
        let mut folded: u8 = 0;
        for &byte in bytes {
            folded ^= byte;
        }
        folded
    }

    /// Return `true` when the additive checksum of `bytes` equals `expected_sum`.
    pub fn verify_sum(bytes: &[u8], expected_sum: u64) -> bool {
        expected_sum == Self::checksum_sum(bytes)
    }

    /// Return `true` when the XOR checksum of `bytes` equals `expected_xor`.
    pub fn verify_xor(bytes: &[u8], expected_xor: u8) -> bool {
        expected_xor == Self::checksum_xor(bytes)
    }
}
899
900// ── Conversion utilities ──────────────────────────────────────────────────────
901
902/// Convert a `RestartData` snapshot to an XYZ-format string.
903///
904/// The XYZ file contains a single frame with element symbol "XX" for all particles.
905pub fn restart_to_xyz(data: &RestartData) -> String {
906    let n = data.n_particles();
907    let mut s = String::new();
908    s.push_str(&format!("{n}\n"));
909    s.push_str(&format!(
910        "Restart step={} time={:.6}\n",
911        data.metadata.step, data.metadata.time
912    ));
913    for (i, p) in data.positions.iter().enumerate() {
914        let sym = if let Some(&t) = data.types.get(i) {
915            match t {
916                0 => "H",
917                1 => "C",
918                2 => "N",
919                3 => "O",
920                _ => "X",
921            }
922        } else {
923            "X"
924        };
925        s.push_str(&format!("{sym} {:.6} {:.6} {:.6}\n", p[0], p[1], p[2]));
926    }
927    s
928}
929
930/// Convert a `RestartData` snapshot to a LAMMPS dump-format string.
931///
932/// Fields: `id type x y z vx vy vz fx fy fz mass`
933pub fn restart_to_lammps_dump(data: &RestartData) -> String {
934    let n = data.n_particles();
935    let step = data.metadata.step;
936    let mut s = String::new();
937    s.push_str(&format!("ITEM: TIMESTEP\n{step}\n"));
938    s.push_str(&format!("ITEM: NUMBER OF ATOMS\n{n}\n"));
939    let bm = &data.box_matrix;
940    s.push_str("ITEM: BOX BOUNDS pp pp pp\n");
941    s.push_str(&format!(
942        "0.0 {:.6}\n0.0 {:.6}\n0.0 {:.6}\n",
943        bm[0][0], bm[1][1], bm[2][2]
944    ));
945    s.push_str("ITEM: ATOMS id type x y z vx vy vz fx fy fz mass\n");
946    for i in 0..n {
947        let p = data.positions.get(i).copied().unwrap_or([0.0; 3]);
948        let v = data.velocities.get(i).copied().unwrap_or([0.0; 3]);
949        let f = data.forces.get(i).copied().unwrap_or([0.0; 3]);
950        let m = data.masses.get(i).copied().unwrap_or(1.0);
951        let tp = data.types.get(i).copied().unwrap_or(0);
952        s.push_str(&format!(
953            "{} {} {:.6} {:.6} {:.6} {:.6} {:.6} {:.6} {:.6} {:.6} {:.6} {:.6}\n",
954            i + 1,
955            tp,
956            p[0],
957            p[1],
958            p[2],
959            v[0],
960            v[1],
961            v[2],
962            f[0],
963            f[1],
964            f[2],
965            m
966        ));
967    }
968    s
969}
970
971// ── Tests ─────────────────────────────────────────────────────────────────────
972
#[cfg(test)]
mod tests {
    //! Unit tests for the restart writers/readers, the validator, the
    //! incremental tracker, the text converters and the checkpoint manager.

    use super::*;

    /// Build a deterministic snapshot with `n` particles.
    ///
    /// Particle `i` sits at `(i, i/2, 0)`, moves with `(0.1*i, 0, 0)`, feels
    /// `(0, -9.81, 0)`, has mass `1.0` and type `i % 3`. The box is a 10×10×10
    /// cube; metadata is fixed at step 42, time 4.2.
    fn make_data(n: usize) -> RestartData {
        let meta = RestartMetadata::new("1.0", 1234567890, 42, 4.2, "oxiphysics-md", "unit test");
        let positions = (0..n).map(|i| [i as f64, i as f64 * 0.5, 0.0]).collect();
        let velocities = (0..n).map(|i| [0.1 * i as f64, 0.0, 0.0]).collect();
        let forces = (0..n).map(|_| [0.0, -9.81, 0.0]).collect();
        let masses = (0..n).map(|_| 1.0).collect();
        let types = (0..n).map(|i| (i % 3) as u32).collect();
        let box_matrix = [[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]];
        RestartData {
            metadata: meta,
            positions,
            velocities,
            forces,
            masses,
            types,
            box_matrix,
            extra_scalars: Vec::new(),
            extra_vectors: Vec::new(),
        }
    }

    /// Path of a scratch directory for checkpoint tests.
    ///
    /// Lives under the platform temp directory (not a hard-coded `/tmp`) and
    /// embeds the process id so concurrent test processes do not share files.
    fn scratch_dir(tag: &str) -> String {
        std::env::temp_dir()
            .join(format!("oxi_test_ckpt_{tag}_{}", std::process::id()))
            .to_string_lossy()
            .into_owned()
    }

    // -- RestartMetadata --

    #[test]
    fn test_metadata_new_fields() {
        let m = RestartMetadata::new("2.0", 100, 5, 0.5, "crate_x", "desc");
        assert_eq!(m.version, "2.0");
        assert_eq!(m.timestamp, 100);
        assert_eq!(m.step, 5);
        assert!((m.time - 0.5).abs() < 1e-12);
        assert_eq!(m.crate_name, "crate_x");
        assert_eq!(m.description, "desc");
    }

    #[test]
    fn test_metadata_default_test() {
        let m = RestartMetadata::default_test();
        assert_eq!(m.version, "1.0");
        assert_eq!(m.step, 0);
    }

    // -- RestartData --

    #[test]
    fn test_restart_data_n_particles() {
        let d = make_data(7);
        assert_eq!(d.n_particles(), 7);
    }

    #[test]
    fn test_restart_data_empty() {
        let d = RestartData::empty(RestartMetadata::default_test());
        assert_eq!(d.n_particles(), 0);
        assert!(d.positions.is_empty());
    }

    #[test]
    fn test_restart_data_single_particle_test() {
        let d = RestartData::single_particle_test();
        assert_eq!(d.n_particles(), 1);
        assert!((d.positions[0][0] - 1.0).abs() < 1e-12);
    }

    // -- Binary round-trip --

    #[test]
    fn test_binary_roundtrip_zero_particles() {
        let d = make_data(0);
        let bytes = RestartWriter::write_binary(&d);
        let recovered = RestartReader::read_binary(&bytes).unwrap();
        assert_eq!(recovered.n_particles(), 0);
        assert_eq!(recovered.metadata.step, 42);
    }

    #[test]
    fn test_binary_roundtrip_positions() {
        let d = make_data(5);
        let bytes = RestartWriter::write_binary(&d);
        let r = RestartReader::read_binary(&bytes).unwrap();
        for i in 0..5 {
            assert!((r.positions[i][0] - d.positions[i][0]).abs() < 1e-12);
            assert!((r.positions[i][1] - d.positions[i][1]).abs() < 1e-12);
        }
    }

    #[test]
    fn test_binary_roundtrip_metadata() {
        let d = make_data(3);
        let bytes = RestartWriter::write_binary(&d);
        let r = RestartReader::read_binary(&bytes).unwrap();
        assert_eq!(r.metadata.version, "1.0");
        assert_eq!(r.metadata.timestamp, 1234567890);
        assert_eq!(r.metadata.step, 42);
        assert!((r.metadata.time - 4.2).abs() < 1e-10);
        assert_eq!(r.metadata.crate_name, "oxiphysics-md");
    }

    #[test]
    fn test_binary_roundtrip_box_matrix() {
        let d = make_data(2);
        let bytes = RestartWriter::write_binary(&d);
        let r = RestartReader::read_binary(&bytes).unwrap();
        assert!((r.box_matrix[0][0] - 10.0).abs() < 1e-12);
        assert!((r.box_matrix[1][1] - 10.0).abs() < 1e-12);
        assert!((r.box_matrix[2][2] - 10.0).abs() < 1e-12);
        assert!(r.box_matrix[0][1].abs() < 1e-12);
    }

    #[test]
    fn test_binary_roundtrip_types() {
        let d = make_data(6);
        let bytes = RestartWriter::write_binary(&d);
        let r = RestartReader::read_binary(&bytes).unwrap();
        for i in 0..6 {
            assert_eq!(r.types[i], (i % 3) as u32);
        }
    }

    #[test]
    fn test_binary_roundtrip_extra_scalars() {
        let mut d = make_data(3);
        d.extra_scalars
            .push(("charge".into(), vec![0.1, -0.2, 0.3]));
        let bytes = RestartWriter::write_binary(&d);
        let r = RestartReader::read_binary(&bytes).unwrap();
        assert_eq!(r.extra_scalars.len(), 1);
        assert_eq!(r.extra_scalars[0].0, "charge");
        assert!((r.extra_scalars[0].1[1] - (-0.2)).abs() < 1e-12);
    }

    #[test]
    fn test_binary_roundtrip_extra_vectors() {
        let mut d = make_data(2);
        d.extra_vectors
            .push(("spin".into(), vec![[0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]));
        let bytes = RestartWriter::write_binary(&d);
        let r = RestartReader::read_binary(&bytes).unwrap();
        assert_eq!(r.extra_vectors.len(), 1);
        assert_eq!(r.extra_vectors[0].0, "spin");
        assert!((r.extra_vectors[0].1[0][2] - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_binary_magic_bytes() {
        let d = make_data(1);
        let bytes = RestartWriter::write_binary(&d);
        assert_eq!(&bytes[0..4], b"OXRS");
    }

    // -- ASCII round-trip --

    #[test]
    fn test_ascii_roundtrip_basic() {
        let d = make_data(4);
        let text = RestartWriter::write_ascii(&d);
        let r = RestartReader::read_ascii(&text).unwrap();
        assert_eq!(r.n_particles(), 4);
        assert_eq!(r.metadata.step, 42);
    }

    #[test]
    fn test_ascii_roundtrip_positions() {
        let d = make_data(3);
        let text = RestartWriter::write_ascii(&d);
        let r = RestartReader::read_ascii(&text).unwrap();
        for i in 0..3 {
            // Looser tolerance than the binary tests: ASCII is a text format.
            assert!((r.positions[i][0] - d.positions[i][0]).abs() < 1e-4);
        }
    }

    #[test]
    fn test_ascii_roundtrip_box() {
        let d = make_data(1);
        let text = RestartWriter::write_ascii(&d);
        let r = RestartReader::read_ascii(&text).unwrap();
        assert!((r.box_matrix[0][0] - 10.0).abs() < 1e-4);
    }

    #[test]
    fn test_ascii_contains_keywords() {
        let d = make_data(2);
        let text = RestartWriter::write_ascii(&d);
        assert!(text.contains("BEGIN_POSITIONS"));
        assert!(text.contains("END_POSITIONS"));
        assert!(text.contains("VERSION"));
        assert!(text.contains("STEP"));
    }

    // -- JSON --

    #[test]
    fn test_json_contains_particles_key() {
        let d = make_data(3);
        let j = RestartWriter::write_json(&d);
        assert!(j.contains("n_particles"));
        assert!(j.contains("positions"));
        assert!(j.contains("velocities"));
    }

    #[test]
    fn test_json_step_present() {
        let d = make_data(1);
        let j = RestartWriter::write_json(&d);
        assert!(j.contains("\"step\": 42"));
    }

    // -- detect_format --

    #[test]
    fn test_detect_format_binary() {
        let d = make_data(1);
        let bytes = RestartWriter::write_binary(&d);
        assert_eq!(RestartReader::detect_format(&bytes), RestartFormat::Binary);
    }

    #[test]
    fn test_detect_format_ascii() {
        let text = "# OxiPhysics restart\nVERSION 1.0\n";
        assert_eq!(
            RestartReader::detect_format(text.as_bytes()),
            RestartFormat::Ascii
        );
    }

    #[test]
    fn test_detect_format_json() {
        let j = "{\"n_particles\": 0}";
        assert_eq!(
            RestartReader::detect_format(j.as_bytes()),
            RestartFormat::Json
        );
    }

    // -- Validator --

    #[test]
    fn test_validator_sum_consistent() {
        let bytes = b"hello world";
        let sum = RestartValidator::checksum_sum(bytes);
        assert!(RestartValidator::verify_sum(bytes, sum));
    }

    #[test]
    fn test_validator_xor_consistent() {
        let bytes = b"test data";
        let xor = RestartValidator::checksum_xor(bytes);
        assert!(RestartValidator::verify_xor(bytes, xor));
    }

    #[test]
    fn test_validator_sum_detects_corruption() {
        let bytes = b"original";
        let sum = RestartValidator::checksum_sum(bytes);
        let corrupt = b"0riginal";
        assert!(!RestartValidator::verify_sum(corrupt, sum));
    }

    #[test]
    fn test_validator_xor_empty() {
        let xor = RestartValidator::checksum_xor(b"");
        assert_eq!(xor, 0);
    }

    // -- IncrementalRestart --

    #[test]
    fn test_incremental_mark_and_extract() {
        let full = make_data(5);
        let mut inc = IncrementalRestart::new();
        inc.mark_changed(1);
        inc.mark_changed(3);
        let delta = inc.extract_delta(&full);
        assert_eq!(delta.n_particles(), 2);
        assert!((delta.positions[0][0] - full.positions[1][0]).abs() < 1e-12);
    }

    #[test]
    fn test_incremental_extract_skips_out_of_range() {
        let full = make_data(3);
        let mut inc = IncrementalRestart::new();
        inc.mark_changed(1);
        inc.mark_changed(10); // no such particle in `full`
        let delta = inc.extract_delta(&full);
        assert_eq!(delta.positions.len(), 1);
        assert!((delta.positions[0][0] - full.positions[1][0]).abs() < 1e-12);
    }

    #[test]
    fn test_incremental_reset_clears_changes() {
        let mut inc = IncrementalRestart::new();
        inc.mark_changed(0);
        inc.mark_changed(2);
        inc.reset();
        assert!(inc.changed.is_empty());
    }

    #[test]
    fn test_incremental_no_duplicates() {
        let mut inc = IncrementalRestart::new();
        inc.mark_changed(0);
        inc.mark_changed(0);
        inc.mark_changed(0);
        assert_eq!(inc.changed.len(), 1);
    }

    // -- Conversion utilities --

    #[test]
    fn test_restart_to_xyz_header_line_count() {
        let d = make_data(4);
        let xyz = restart_to_xyz(&d);
        let lines: Vec<&str> = xyz.lines().collect();
        // Exactly: count + comment + 4 atoms.
        assert_eq!(lines.len(), 6);
        assert_eq!(lines[0], "4");
        assert_eq!(lines[1], "Restart step=42 time=4.200000");
        // Types 0 and 1 map to H and C.
        assert_eq!(lines[2], "H 0.000000 0.000000 0.000000");
        assert_eq!(lines[3], "C 1.000000 0.500000 0.000000");
    }

    #[test]
    fn test_restart_to_lammps_dump_has_timestep() {
        let d = make_data(2);
        let dump = restart_to_lammps_dump(&d);
        assert!(dump.contains("ITEM: TIMESTEP"));
        assert!(dump.contains("ITEM: NUMBER OF ATOMS"));
        assert!(dump.contains("ITEM: ATOMS"));
    }

    #[test]
    fn test_restart_to_lammps_dump_atom_count() {
        let d = make_data(3);
        let dump = restart_to_lammps_dump(&d);
        // Everything after the "ITEM: ATOMS" header is one row per atom.
        let atom_lines: Vec<&str> = dump
            .lines()
            .skip_while(|l| !l.starts_with("ITEM: ATOMS"))
            .skip(1)
            .collect();
        assert_eq!(atom_lines.len(), 3);
    }

    // -- CheckpointManager (uses a scratch directory under the platform temp dir) --

    #[test]
    fn test_checkpoint_manager_list_empty() {
        let dir = scratch_dir("empty");
        // Start from a clean slate in case an earlier run left files behind.
        let _ = std::fs::remove_dir_all(&dir);
        let mgr = CheckpointManager::new(dir.as_str(), 3);
        assert!(mgr.list_checkpoints().is_empty());
        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn test_checkpoint_manager_prune_keeps_max() {
        let dir = scratch_dir("prune");
        let _ = std::fs::remove_dir_all(&dir);
        let mut mgr = CheckpointManager::new(dir.as_str(), 2);
        let d = make_data(1);
        mgr.save_checkpoint(&d, 1);
        mgr.save_checkpoint(&d, 2);
        mgr.save_checkpoint(&d, 3);
        assert_eq!(mgr.list_checkpoints().len(), 2);
        // Latest should be step 3.
        let latest = mgr.list_checkpoints().last().unwrap().0;
        assert_eq!(latest, 3);
        let _ = std::fs::remove_dir_all(&dir);
    }
}