Skip to main content

oxiphysics_io/
binary_formats.rs

1#![allow(clippy::needless_range_loop)]
2// Copyright 2026 COOLJAPAN OU (Team KitaSan)
3// SPDX-License-Identifier: Apache-2.0
4
5//! Binary format I/O for physics simulation data.
6//!
7//! Provides binary trajectory writers/readers, DCD format support,
8//! minimal XTC-like compression, energy log writers, and checkpoint I/O.
9
10#![allow(dead_code)]
11#![allow(clippy::too_many_arguments)]
12
13// ── BinaryHeader ─────────────────────────────────────────────────────────────
14
/// Four-byte signature (`OXIP`) that opens every OxiPhysics binary trajectory.
pub const OXIPHY_MAGIC: [u8; 4] = *b"OXIP";

/// Fixed-size preamble of an OxiPhysics trajectory file.
///
/// All fields are stored little-endian, back to back, in this order:
///
/// | offset | bytes | field                 |
/// |-------:|------:|-----------------------|
/// | 0      | 4     | `magic`               |
/// | 4      | 4     | `version` (`u32`)     |
/// | 8      | 8     | `n_particles` (`u64`) |
/// | 16     | 8     | `n_frames` (`u64`)    |
/// | 24     | 8     | `dt` (`f64`)          |
#[derive(Debug, Clone, PartialEq)]
pub struct BinaryHeader {
    /// File signature; a well-formed header carries [`OXIPHY_MAGIC`] here.
    pub magic: [u8; 4],
    /// Revision number of the file format.
    pub version: u32,
    /// How many particles each frame holds.
    pub n_particles: u64,
    /// How many frames the file stores.
    pub n_frames: u64,
    /// Integration time step (seconds).
    pub dt: f64,
}

impl BinaryHeader {
    /// Number of bytes occupied by a serialised header.
    pub const SIZE: usize = 32;

    /// Build a version-1 header stamped with [`OXIPHY_MAGIC`].
    pub fn new(n_particles: u64, n_frames: u64, dt: f64) -> Self {
        BinaryHeader {
            dt,
            n_frames,
            n_particles,
            version: 1,
            magic: OXIPHY_MAGIC,
        }
    }

    /// Check the signature and the version number.
    ///
    /// Returns `true` only when `magic` equals [`OXIPHY_MAGIC`] and `version`
    /// is at least 1.
    pub fn validate(&self) -> bool {
        let magic_ok = self.magic == OXIPHY_MAGIC;
        let version_ok = self.version != 0;
        magic_ok && version_ok
    }

    /// Append the [`BinaryHeader::SIZE`]-byte encoding of this header to `buf`.
    pub fn write(&self, buf: &mut Vec<u8>) {
        // Assemble the fixed layout in a stack array, then append it in one go.
        let mut encoded = [0u8; Self::SIZE];
        encoded[..4].copy_from_slice(&self.magic);
        encoded[4..8].copy_from_slice(&self.version.to_le_bytes());
        encoded[8..16].copy_from_slice(&self.n_particles.to_le_bytes());
        encoded[16..24].copy_from_slice(&self.n_frames.to_le_bytes());
        encoded[24..].copy_from_slice(&self.dt.to_le_bytes());
        buf.extend_from_slice(&encoded);
    }

    /// Decode a header from the first [`BinaryHeader::SIZE`] bytes of `buf`.
    ///
    /// Returns `None` when `buf` holds fewer bytes than that. The decoded
    /// values are not checked here; call [`BinaryHeader::validate`] for that.
    pub fn read(buf: &[u8]) -> Option<Self> {
        let raw = buf.get(..Self::SIZE)?;
        let (magic, raw) = raw.split_at(4);
        let (version, raw) = raw.split_at(4);
        let (n_particles, raw) = raw.split_at(8);
        let (n_frames, dt) = raw.split_at(8);
        Some(Self {
            magic: magic.try_into().ok()?,
            version: u32::from_le_bytes(version.try_into().ok()?),
            n_particles: u64::from_le_bytes(n_particles.try_into().ok()?),
            n_frames: u64::from_le_bytes(n_frames.try_into().ok()?),
            dt: f64::from_le_bytes(dt.try_into().ok()?),
        })
    }
}
92
93// ── ParticleFrame ─────────────────────────────────────────────────────────────
94
/// A single snapshot of particle state at one simulation step.
///
/// Serialised layout (little-endian, no padding):
/// `step: u64`, `time: f64`, then every position as three `f32`, every
/// velocity as three `f32`, and finally every mass as one `f32`.
#[derive(Debug, Clone, PartialEq)]
pub struct ParticleFrame {
    /// Simulation step index.
    pub step: u64,
    /// Simulation time (seconds).
    pub time: f64,
    /// Particle positions `[x, y, z]` in metres.
    pub positions: Vec<[f32; 3]>,
    /// Particle velocities `[vx, vy, vz]` in m/s.
    pub velocities: Vec<[f32; 3]>,
    /// Particle masses in kg.
    pub masses: Vec<f32>,
}

impl ParticleFrame {
    /// Bytes that precede the per-particle data: `step` (8) + `time` (8).
    const FIXED_BYTES: usize = 16;
    /// Bytes stored per particle: position (12) + velocity (12) + mass (4).
    const BYTES_PER_PARTICLE: usize = 28;

    /// Create a frame for `n` particles with zeroed positions and velocities
    /// and unit masses.
    pub fn new(step: u64, time: f64, n: usize) -> Self {
        Self {
            step,
            time,
            positions: vec![[0.0; 3]; n],
            velocities: vec![[0.0; 3]; n],
            masses: vec![1.0; n],
        }
    }

    /// Byte size of a serialised frame for `n` particles.
    ///
    /// The result saturates at `usize::MAX` instead of overflowing, so an
    /// absurd `n` (for example one read from a corrupt header) yields a size
    /// that no real buffer can satisfy rather than a panic or a wrapped value.
    pub fn serialized_size(n: usize) -> usize {
        n.saturating_mul(Self::BYTES_PER_PARTICLE)
            .saturating_add(Self::FIXED_BYTES)
    }

    /// Serialise this frame into `buf` (appends bytes).
    ///
    /// Each of `positions`, `velocities` and `masses` is written with its own
    /// length. The stream carries no per-frame particle count, so the three
    /// vectors must all have the length the reader will be told to expect;
    /// otherwise [`ParticleFrame::deserialize`] will misinterpret the bytes.
    pub fn serialize(&self, buf: &mut Vec<u8>) {
        buf.reserve(
            Self::FIXED_BYTES
                + (self.positions.len() + self.velocities.len()) * 12
                + self.masses.len() * 4,
        );
        buf.extend_from_slice(&self.step.to_le_bytes());
        buf.extend_from_slice(&self.time.to_le_bytes());
        // Positions first, then velocities; both are flat runs of f32 triples.
        for component in self.positions.iter().chain(&self.velocities).flatten() {
            buf.extend_from_slice(&component.to_le_bytes());
        }
        for mass in &self.masses {
            buf.extend_from_slice(&mass.to_le_bytes());
        }
    }

    /// Deserialise a frame from `buf` starting at byte 0, given `n` particles.
    ///
    /// Returns `None` if `buf` is too short, including when `n` is so large
    /// that the frame size does not fit in `usize`. Bytes beyond the frame are
    /// ignored.
    pub fn deserialize(buf: &[u8], n: usize) -> Option<Self> {
        // Checked arithmetic: `n` typically comes from an untrusted header.
        let total = n
            .checked_mul(Self::BYTES_PER_PARTICLE)?
            .checked_add(Self::FIXED_BYTES)?;
        let frame = buf.get(..total)?;
        let (head, body) = frame.split_at(Self::FIXED_BYTES);
        let step = u64::from_le_bytes(head[..8].try_into().ok()?);
        let time = f64::from_le_bytes(head[8..].try_into().ok()?);
        // `n * 12` cannot overflow here because `n * 28 + 16` did not.
        let (pos_raw, rest) = body.split_at(n * 12);
        let (vel_raw, mass_raw) = rest.split_at(n * 12);
        Some(Self {
            step,
            time,
            positions: Self::decode_vec3s(pos_raw),
            velocities: Self::decode_vec3s(vel_raw),
            masses: mass_raw
                .chunks_exact(4)
                .map(|c| Self::f32_at(c, 0))
                .collect(),
        })
    }

    /// Decode the `index`-th little-endian `f32` of `raw`.
    fn f32_at(raw: &[u8], index: usize) -> f32 {
        let at = index * 4;
        f32::from_le_bytes([raw[at], raw[at + 1], raw[at + 2], raw[at + 3]])
    }

    /// Decode a run of 12-byte `[f32; 3]` triples.
    fn decode_vec3s(raw: &[u8]) -> Vec<[f32; 3]> {
        raw.chunks_exact(12)
            .map(|c| [Self::f32_at(c, 0), Self::f32_at(c, 1), Self::f32_at(c, 2)])
            .collect()
    }
}
189
190// ── BinaryTrajectoryWriter ────────────────────────────────────────────────────
191
192/// Streaming writer for the OxiPhysics binary trajectory format.
193///
194/// Call [`write_frame`](BinaryTrajectoryWriter::write_frame) for each step,
195/// then [`finalize`](BinaryTrajectoryWriter::finalize) to obtain the complete
196/// byte buffer (header `n_frames` is patched at finalization).
197#[derive(Debug)]
198pub struct BinaryTrajectoryWriter {
199    buf: Vec<u8>,
200    /// Embedded file header.
201    pub header: BinaryHeader,
202    frames_written: u64,
203}
204
205impl BinaryTrajectoryWriter {
206    /// Create a writer for the given number of particles and time step.
207    pub fn new(n_particles: u64, dt: f64) -> Self {
208        let header = BinaryHeader::new(n_particles, 0, dt);
209        let mut buf = Vec::new();
210        header.write(&mut buf); // placeholder; patched in finalize
211        Self {
212            buf,
213            header,
214            frames_written: 0,
215        }
216    }
217
218    /// Append one frame to the internal buffer.
219    pub fn write_frame(&mut self, frame: &ParticleFrame) {
220        frame.serialize(&mut self.buf);
221        self.frames_written += 1;
222    }
223
224    /// Finalise: patch `n_frames` in the header and return the buffer.
225    pub fn finalize(mut self) -> Vec<u8> {
226        // Patch n_frames at byte offset 16..24 in the header
227        let nf = self.frames_written.to_le_bytes();
228        self.buf[16..24].copy_from_slice(&nf);
229        self.buf
230    }
231
232    /// Number of frames written so far.
233    pub fn frame_count(&self) -> u64 {
234        self.frames_written
235    }
236}
237
238// ── BinaryTrajectoryReader ────────────────────────────────────────────────────
239
240/// Reader for the OxiPhysics binary trajectory format.
241#[derive(Debug)]
242pub struct BinaryTrajectoryReader {
243    buf: Vec<u8>,
244    cursor: usize,
245    /// Parsed file header.
246    pub header: BinaryHeader,
247}
248
249impl BinaryTrajectoryReader {
250    /// Create a reader from a byte buffer.
251    ///
252    /// Returns `None` if the header is invalid or magic bytes mismatch.
253    pub fn new(buf: Vec<u8>) -> Option<Self> {
254        let header = BinaryHeader::read(&buf)?;
255        if !header.validate() {
256            return None;
257        }
258        Some(Self {
259            buf,
260            cursor: BinaryHeader::SIZE,
261            header,
262        })
263    }
264
265    /// Read the next frame sequentially.
266    ///
267    /// Returns `None` when no more frames are available.
268    pub fn read_frame(&mut self) -> Option<ParticleFrame> {
269        let n = self.header.n_particles as usize;
270        let needed = ParticleFrame::serialized_size(n);
271        if self.cursor + needed > self.buf.len() {
272            return None;
273        }
274        let frame = ParticleFrame::deserialize(&self.buf[self.cursor..], n)?;
275        self.cursor += needed;
276        Some(frame)
277    }
278
279    /// Seek to frame index `i` (0-based).
280    ///
281    /// Does nothing if `i` is out of range.
282    pub fn seek_frame(&mut self, i: usize) {
283        let n = self.header.n_particles as usize;
284        let frame_size = ParticleFrame::serialized_size(n);
285        let new_cursor = BinaryHeader::SIZE + i * frame_size;
286        if new_cursor <= self.buf.len() {
287            self.cursor = new_cursor;
288        }
289    }
290
291    /// Return the number of frames encoded in the header.
292    pub fn n_frames(&self) -> u64 {
293        self.header.n_frames
294    }
295}
296
297// ── DcdWriter ────────────────────────────────────────────────────────────────
298
/// In-memory builder for CHARMM DCD trajectory files.
///
/// Every block in the output is a Fortran-style record: the payload is
/// bracketed by its byte length as a 4-byte little-endian integer, once before
/// and once after.
#[derive(Debug)]
pub struct DcdWriter {
    /// Number of atoms.
    pub n_atoms: usize,
    /// Number of frames written so far.
    pub n_frames: usize,
    /// Bytes produced so far: the three header records, then the frames.
    buf: Vec<u8>,
    /// Position of the `NFILE` field inside `buf`, patched by `finalize`.
    nfile_offset: usize,
}

/// Append `v` to `buf` as four little-endian bytes.
fn write_i32_le(buf: &mut Vec<u8>, v: i32) {
    buf.extend(v.to_le_bytes());
}

/// Append `v` to `buf` as four little-endian bytes.
fn write_f32_le(buf: &mut Vec<u8>, v: f32) {
    buf.extend(v.to_le_bytes());
}

/// Append `data` to `buf` wrapped in leading and trailing length markers.
fn fortran_record(buf: &mut Vec<u8>, data: &[u8]) {
    let marker = (data.len() as i32).to_le_bytes();
    buf.extend_from_slice(&marker);
    buf.extend_from_slice(data);
    buf.extend_from_slice(&marker);
}

impl DcdWriter {
    /// Start a DCD stream for `n_atoms` atoms with time step `dt`.
    ///
    /// Three records are emitted up front: the 84-byte `CORD` header, a title
    /// block with one 80-character line, and the atom count.
    pub fn new(n_atoms: usize, dt: f32) -> Self {
        // CORD header payload.
        let mut cord = Vec::with_capacity(84);
        cord.extend_from_slice(b"CORD");
        let nfile_in_payload = cord.len();
        // NFILE (placeholder), ISTART, NSAVC, NSTEP, then five zero fields.
        for field in [0i32, 0, 1, 0, 0, 0, 0, 0, 0] {
            write_i32_le(&mut cord, field);
        }
        write_f32_le(&mut cord, dt); // DELTA
        cord.extend_from_slice(&[0u8; 9 * 4]); // nine zero padding integers
        write_i32_le(&mut cord, 24); // CHARMM version marker

        let mut buf = Vec::new();
        fortran_record(&mut buf, &cord);
        // In the finished buffer NFILE sits after the record's 4-byte length
        // prefix plus the 4-byte "CORD" tag that opens the payload.
        let nfile_offset = 4 + nfile_in_payload;

        // Title block: NTITLE followed by one line padded with spaces to 80 bytes.
        let title = b"OxiPhysics DCD                                                  ";
        let mut line = [b' '; 80];
        let shown = title.len().min(80);
        line[..shown].copy_from_slice(&title[..shown]);
        let mut title_payload = Vec::with_capacity(84);
        write_i32_le(&mut title_payload, 1); // NTITLE
        title_payload.extend_from_slice(&line);
        fortran_record(&mut buf, &title_payload);

        // Atom-count block.
        fortran_record(&mut buf, &(n_atoms as i32).to_le_bytes());

        Self {
            n_atoms,
            n_frames: 0,
            buf,
            nfile_offset,
        }
    }

    /// Append one frame as three records holding the x, y and z coordinates.
    ///
    /// Each slice must have length `n_atoms`; this is not checked here.
    pub fn write_frame(&mut self, x: &[f32], y: &[f32], z: &[f32]) {
        for axis in [x, y, z] {
            let payload: Vec<u8> = axis.iter().flat_map(|c| c.to_le_bytes()).collect();
            fortran_record(&mut self.buf, &payload);
        }
        self.n_frames += 1;
    }

    /// Consume the writer and return the finished DCD bytes.
    ///
    /// The `NFILE` header field is overwritten with the number of frames
    /// written.
    pub fn finalize(self) -> Vec<u8> {
        let Self {
            mut buf,
            n_frames,
            nfile_offset,
            ..
        } = self;
        let count = (n_frames as i32).to_le_bytes();
        buf[nfile_offset..nfile_offset + 4].copy_from_slice(&count);
        buf
    }
}
405
406// ── DcdReader ────────────────────────────────────────────────────────────────
407
/// Reader for CHARMM DCD trajectory files (pure byte-slice, no I/O).
///
/// Only the layout used here is understood: a `CORD` header record, one
/// title record, one atom-count record, then frames of three coordinate
/// records (x, y, z) each. All integers are little-endian.
#[derive(Debug)]
pub struct DcdReader {
    /// Whole file contents.
    buf: Vec<u8>,
    /// Number of atoms (from header).
    pub n_atoms: usize,
    /// Number of frames (from header).
    pub n_frames: usize,
    /// Byte offset where the first frame begins.
    frames_start: usize,
    /// Byte size of one frame.
    frame_size: usize,
}

impl DcdReader {
    /// Read the little-endian `i32` at byte offset `at`, or `None` if those
    /// four bytes are not inside `buf`.
    fn i32_at(buf: &[u8], at: usize) -> Option<i32> {
        let end = at.checked_add(4)?;
        let raw: [u8; 4] = buf.get(at..end)?.try_into().ok()?;
        Some(i32::from_le_bytes(raw))
    }

    /// Like [`DcdReader::i32_at`] but rejects negative values, which are
    /// never valid for a length or a count.
    fn len_at(buf: &[u8], at: usize) -> Option<usize> {
        usize::try_from(Self::i32_at(buf, at)?).ok()
    }

    /// Parse the DCD header from `buf` and return a reader, or `None` on error.
    ///
    /// Errors include a buffer too short for the header records, a first
    /// record that does not start with `CORD`, and negative lengths or
    /// counts. The frame data itself is not validated here; see
    /// [`DcdReader::read_frame`].
    pub fn parse_header(buf: Vec<u8>) -> Option<Self> {
        // First record: length prefix, payload, length suffix.
        let rec_len = Self::len_at(&buf, 0)?;
        let hdr_end = rec_len.checked_add(4)?;
        let mut offset = hdr_end.checked_add(4)?;
        if buf.len() < offset {
            return None;
        }
        let hdr = &buf[4..hdr_end];
        if hdr.get(..4)? != b"CORD" {
            return None;
        }
        let n_frames = Self::len_at(hdr, 4)?;

        // Title record: skipped wholesale (prefix + payload + suffix).
        let title_len = Self::len_at(&buf, offset)?;
        offset = offset.checked_add(title_len)?.checked_add(8)?;

        // Atom-count record: the count is the first integer of the payload.
        let natom_len = Self::len_at(&buf, offset)?;
        if natom_len < 4 {
            return None;
        }
        offset = offset.checked_add(4)?;
        let n_atoms = Self::len_at(&buf, offset)?;
        let frames_start = offset.checked_add(natom_len)?.checked_add(4)?;

        // Each frame holds three coordinate records of n_atoms f32 values,
        // each wrapped in a 4-byte prefix and suffix.
        let coord_block = n_atoms.checked_mul(4)?.checked_add(8)?;
        let frame_size = coord_block.checked_mul(3)?;
        Some(Self {
            buf,
            n_atoms,
            n_frames,
            frames_start,
            frame_size,
        })
    }

    /// Read the coordinates for frame `frame_idx` (0-based).
    ///
    /// Returns `(x, y, z)`, each a `Vec<f32>` of length `n_atoms`. Returns
    /// `None` if `frame_idx` is not below `n_frames`, or if the buffer ends
    /// before the frame's coordinate data does (for example when the header
    /// claims more frames than the file contains). The record length markers
    /// around each coordinate array are skipped, not verified.
    pub fn read_frame(&self, frame_idx: usize) -> Option<(Vec<f32>, Vec<f32>, Vec<f32>)> {
        if frame_idx >= self.n_frames {
            return None;
        }
        let start = frame_idx
            .checked_mul(self.frame_size)?
            .checked_add(self.frames_start)?;
        let payload_len = self.n_atoms.checked_mul(4)?;
        let coord_block = payload_len.checked_add(8)?;
        let read_axis = |axis: usize| -> Option<Vec<f32>> {
            // Skip the 4-byte length prefix of this coordinate record.
            let data_start = axis
                .checked_mul(coord_block)?
                .checked_add(start)?
                .checked_add(4)?;
            let data_end = data_start.checked_add(payload_len)?;
            let raw = self.buf.get(data_start..data_end)?;
            Some(
                raw.chunks_exact(4)
                    .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
                    .collect(),
            )
        };
        let x = read_axis(0)?;
        let y = read_axis(1)?;
        let z = read_axis(2)?;
        Some((x, y, z))
    }
}
497
498// ── XtcEncoder ───────────────────────────────────────────────────────────────
499
/// Minimal XTC-like lossy compression for particle positions.
///
/// Positions are quantised to integer multiples of `1e-3` nm (the default
/// XTC precision). The encoded frame is, in order and all big-endian:
/// a 4-byte particle count, the 4-byte bit pattern of the `f32` precision,
/// then three `i32` values per particle.
#[derive(Debug, Default)]
pub struct XtcEncoder;

impl XtcEncoder {
    /// Create a new encoder.
    pub fn new() -> Self {
        Self
    }

    /// Compress `positions` into a byte buffer using integer quantisation.
    ///
    /// The precision is fixed at 1000 (i.e. 3 decimal places in nm units).
    /// Each coordinate is multiplied by the precision, rounded, and stored as
    /// an `i32`; the stored particle count is `positions.len()` truncated to
    /// `u32`.
    pub fn compress_frame(positions: &[[f32; 3]]) -> Vec<u8> {
        let precision: f32 = 1000.0;
        let mut buf = Vec::with_capacity(8 + positions.len() * 12);
        buf.extend_from_slice(&(positions.len() as u32).to_be_bytes());
        buf.extend_from_slice(&precision.to_bits().to_be_bytes());
        for &component in positions.iter().flatten() {
            let quantised = (component * precision).round() as i32;
            buf.extend_from_slice(&quantised.to_be_bytes());
        }
        buf
    }

    /// Decompress a byte buffer produced by [`compress_frame`](XtcEncoder::compress_frame).
    ///
    /// `n` is the number of particles to decode; it may be smaller than the
    /// stored count to read only a prefix of the frame.
    ///
    /// Returns `None` if `buf` is malformed: shorter than the 8-byte header,
    /// a precision that is zero, NaN or infinite, `n` larger than the particle
    /// count stored in the header, or fewer than `n` encoded particles
    /// present.
    pub fn decompress_frame(buf: &[u8], n: usize) -> Option<Vec<[f32; 3]>> {
        let header = buf.get(..8)?;
        let stored_n = u32::from_be_bytes(header[..4].try_into().ok()?);
        let precision = f32::from_bits(u32::from_be_bytes(header[4..].try_into().ok()?));
        // NaN compares false against everything, so test finiteness explicitly.
        if !precision.is_finite() || precision.abs() < 1e-9 {
            return None;
        }
        // Never decode beyond what the frame says it contains, even if the
        // buffer happens to have trailing bytes.
        match u32::try_from(n) {
            Ok(requested) if requested <= stored_n => {}
            _ => return None,
        }
        let payload_end = n.checked_mul(12)?.checked_add(8)?;
        let payload = buf.get(8..payload_end)?;
        let positions = payload
            .chunks_exact(12)
            .map(|triple| {
                let axis = |at: usize| {
                    let q = i32::from_be_bytes([
                        triple[at],
                        triple[at + 1],
                        triple[at + 2],
                        triple[at + 3],
                    ]);
                    q as f32 / precision
                };
                [axis(0), axis(4), axis(8)]
            })
            .collect();
        Some(positions)
    }
}
564
565// ── EnergyLogWriter ──────────────────────────────────────────────────────────
566
/// One row of the energy log: the quantities recorded for a single step.
#[derive(Debug, Clone)]
pub struct EnergyEntry {
    /// Simulation step index.
    pub step: u64,
    /// Kinetic energy (J or reduced units).
    pub ke: f64,
    /// Potential energy (J or reduced units).
    pub pe: f64,
    /// Temperature (K).
    pub temp: f64,
    /// Pressure (Pa or reduced units).
    pub pressure: f64,
}

/// In-memory collector for an energy time series that renders to CSV.
///
/// Rows are kept in insertion order until
/// [`to_csv_string`](EnergyLogWriter::to_csv_string) is called.
#[derive(Debug, Default)]
pub struct EnergyLogWriter {
    /// Recorded rows, oldest first.
    entries: Vec<EnergyEntry>,
}

impl EnergyLogWriter {
    /// Create a writer with no recorded rows.
    pub fn new() -> Self {
        Self::default()
    }

    /// Record the energies, temperature and pressure observed at `step`.
    pub fn write_step(&mut self, step: u64, ke: f64, pe: f64, temp: f64, pressure: f64) {
        let row = EnergyEntry {
            step,
            ke,
            pe,
            temp,
            pressure,
        };
        self.entries.push(row);
    }

    /// Render every recorded row as CSV, preceded by a header line.
    ///
    /// The `total_energy` column is computed as `ke + pe`; all floating-point
    /// columns are printed with six decimal places.
    pub fn to_csv_string(&self) -> String {
        let heading = String::from("step,ke,pe,total_energy,temperature,pressure\n");
        let rows = self.entries.iter().map(|row| {
            let total = row.ke + row.pe;
            format!(
                "{},{:.6},{:.6},{:.6},{:.6},{:.6}\n",
                row.step, row.ke, row.pe, total, row.temp, row.pressure
            )
        });
        std::iter::once(heading).chain(rows).collect()
    }

    /// Number of entries recorded.
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// Returns `true` if no entries have been written.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
636
637// ── CheckpointWriter ─────────────────────────────────────────────────────────
638
/// Serialises and deserialises simulation checkpoints as byte buffers.
///
/// A checkpoint stores positions, velocities, the current step, and the
/// current simulation time. All values are encoded as little-endian bytes:
/// the magic `OXCK` (4), the particle count (`u64`), the step (`u64`), the
/// time (`f64`), then all positions followed by all velocities, each as
/// three `f64` values.
#[derive(Debug, Default)]
pub struct CheckpointWriter;

impl CheckpointWriter {
    /// Size of the fixed header: magic(4) + n(8) + step(8) + time(8).
    const HEADER_BYTES: usize = 28;
    /// Size of one encoded `[f64; 3]`.
    const VEC3_BYTES: usize = 24;

    /// Create a new checkpoint writer/reader.
    pub fn new() -> Self {
        Self
    }

    /// Serialise a simulation state to a byte buffer.
    ///
    /// * `positions`  – `N` particle positions `\[x, y, z\]` (f64 metres)
    /// * `velocities` – `N` particle velocities `\[vx, vy, vz\]` (f64 m/s)
    /// * `step`       – current step index
    /// * `time`       – current simulation time (seconds)
    ///
    /// # Panics
    ///
    /// Panics if `positions` and `velocities` differ in length. Only one
    /// particle count is stored, so a mismatch would produce a checkpoint
    /// that cannot be loaded back correctly.
    pub fn save_state(
        positions: &[[f64; 3]],
        velocities: &[[f64; 3]],
        step: u64,
        time: f64,
    ) -> Vec<u8> {
        assert_eq!(
            positions.len(),
            velocities.len(),
            "positions and velocities must have the same length"
        );
        let n = positions.len();
        let mut buf = Vec::with_capacity(Self::HEADER_BYTES + n * 2 * Self::VEC3_BYTES);
        buf.extend_from_slice(b"OXCK"); // magic
        buf.extend_from_slice(&(n as u64).to_le_bytes());
        buf.extend_from_slice(&step.to_le_bytes());
        buf.extend_from_slice(&time.to_le_bytes());
        // Positions first, then velocities, as flat runs of f64 components.
        for component in positions.iter().chain(velocities).flatten() {
            buf.extend_from_slice(&component.to_le_bytes());
        }
        buf
    }

    /// Deserialise a checkpoint from `buf`.
    ///
    /// Returns `(positions, velocities, step, time)`.
    ///
    /// # Panics
    ///
    /// Panics if the buffer is shorter than the header, if the magic bytes
    /// are wrong, or if the buffer is truncated relative to the particle
    /// count declared in its header. The length is verified before anything
    /// is allocated, so a corrupt count cannot trigger a huge allocation.
    pub fn load_state(buf: &[u8]) -> (Vec<[f64; 3]>, Vec<[f64; 3]>, u64, f64) {
        assert!(buf.len() >= Self::HEADER_BYTES, "checkpoint buffer too short");
        assert_eq!(&buf[0..4], b"OXCK", "bad checkpoint magic");
        let u64_at = |at: usize| {
            let mut raw = [0u8; 8];
            raw.copy_from_slice(&buf[at..at + 8]);
            u64::from_le_bytes(raw)
        };
        let declared = u64_at(4);
        let step = u64_at(12);
        let time = f64::from_bits(u64_at(20));

        // Size of one vector block and of the whole checkpoint, computed with
        // overflow checks because `declared` is untrusted.
        let block_len = usize::try_from(declared)
            .ok()
            .and_then(|n| n.checked_mul(Self::VEC3_BYTES));
        let total_len = block_len
            .and_then(|b| b.checked_mul(2))
            .and_then(|b| b.checked_add(Self::HEADER_BYTES));
        let (block_len, total_len) = match (block_len, total_len) {
            (Some(block), Some(total)) if buf.len() >= total => (block, total),
            _ => panic!(
                "checkpoint buffer truncated: header declares {} particles but only {} bytes are present",
                declared,
                buf.len()
            ),
        };

        let (pos_raw, vel_raw) = buf[Self::HEADER_BYTES..total_len].split_at(block_len);
        (
            Self::decode_vec3s(pos_raw),
            Self::decode_vec3s(vel_raw),
            step,
            time,
        )
    }

    /// Decode a run of 24-byte little-endian `[f64; 3]` triples.
    fn decode_vec3s(raw: &[u8]) -> Vec<[f64; 3]> {
        raw.chunks_exact(Self::VEC3_BYTES)
            .map(|triple| {
                let axis = |at: usize| {
                    let mut bytes = [0u8; 8];
                    bytes.copy_from_slice(&triple[at..at + 8]);
                    f64::from_le_bytes(bytes)
                };
                [axis(0), axis(8), axis(16)]
            })
            .collect()
    }
}
740
741// ── Tests ─────────────────────────────────────────────────────────────────────
742
743#[cfg(test)]
744mod tests {
745    use super::*;
746
747    // ── BinaryHeader tests ────────────────────────────────────────────────
748
749    #[test]
750    fn test_header_new_fields() {
751        let h = BinaryHeader::new(100, 50, 0.001);
752        assert_eq!(h.magic, OXIPHY_MAGIC);
753        assert_eq!(h.version, 1);
754        assert_eq!(h.n_particles, 100);
755        assert_eq!(h.n_frames, 50);
756        assert!((h.dt - 0.001).abs() < 1e-15);
757    }
758
759    #[test]
760    fn test_header_validate_ok() {
761        let h = BinaryHeader::new(10, 5, 0.01);
762        assert!(h.validate());
763    }
764
765    #[test]
766    fn test_header_validate_bad_magic() {
767        let mut h = BinaryHeader::new(10, 5, 0.01);
768        h.magic = [0u8; 4];
769        assert!(!h.validate());
770    }
771
772    #[test]
773    fn test_header_validate_bad_version() {
774        let mut h = BinaryHeader::new(10, 5, 0.01);
775        h.version = 0;
776        assert!(!h.validate());
777    }
778
779    #[test]
780    fn test_header_write_read_roundtrip() {
781        let h = BinaryHeader::new(42, 7, 0.002);
782        let mut buf = Vec::new();
783        h.write(&mut buf);
784        assert_eq!(buf.len(), BinaryHeader::SIZE);
785        let h2 = BinaryHeader::read(&buf).unwrap();
786        assert_eq!(h, h2);
787    }
788
789    #[test]
790    fn test_header_read_too_short() {
791        let buf = vec![0u8; 10];
792        assert!(BinaryHeader::read(&buf).is_none());
793    }
794
795    #[test]
796    fn test_header_size_constant() {
797        let mut buf = Vec::new();
798        BinaryHeader::new(1, 1, 1.0).write(&mut buf);
799        assert_eq!(buf.len(), BinaryHeader::SIZE);
800    }
801
802    // ── ParticleFrame tests ───────────────────────────────────────────────
803
804    #[test]
805    fn test_frame_new_sizes() {
806        let f = ParticleFrame::new(0, 0.0, 5);
807        assert_eq!(f.positions.len(), 5);
808        assert_eq!(f.velocities.len(), 5);
809        assert_eq!(f.masses.len(), 5);
810    }
811
812    #[test]
813    fn test_frame_serialized_size() {
814        // 8 + 8 + 3*12 + 3*12 + 3*4 = 16 + 36 + 36 + 12 = 100
815        assert_eq!(ParticleFrame::serialized_size(3), 100);
816    }
817
818    #[test]
819    fn test_frame_serialize_deserialize_roundtrip() {
820        let mut f = ParticleFrame::new(7, 0.014, 3);
821        f.positions[0] = [1.0, 2.0, 3.0];
822        f.velocities[1] = [0.5, 0.6, 0.7];
823        f.masses[2] = 4.0;
824        let mut buf = Vec::new();
825        f.serialize(&mut buf);
826        let f2 = ParticleFrame::deserialize(&buf, 3).unwrap();
827        assert_eq!(f2.step, 7);
828        assert!((f2.time - 0.014).abs() < 1e-12);
829        assert_eq!(f2.positions[0], [1.0f32, 2.0, 3.0]);
830        assert!((f2.velocities[1][1] - 0.6).abs() < 1e-6);
831        assert!((f2.masses[2] - 4.0).abs() < 1e-6);
832    }
833
834    #[test]
835    fn test_frame_deserialize_too_short() {
836        let buf = vec![0u8; 10];
837        assert!(ParticleFrame::deserialize(&buf, 5).is_none());
838    }
839
840    #[test]
841    fn test_frame_zero_particles() {
842        let f = ParticleFrame::new(0, 0.0, 0);
843        let mut buf = Vec::new();
844        f.serialize(&mut buf);
845        let f2 = ParticleFrame::deserialize(&buf, 0).unwrap();
846        assert_eq!(f2.positions.len(), 0);
847    }
848
849    // ── BinaryTrajectoryWriter/Reader tests ───────────────────────────────
850
851    #[test]
852    fn test_trajectory_write_read_roundtrip() {
853        let n_particles = 4u64;
854        let dt = 0.001;
855        let mut writer = BinaryTrajectoryWriter::new(n_particles, dt);
856        for step in 0u64..3 {
857            let mut frame = ParticleFrame::new(step, step as f64 * dt, n_particles as usize);
858            for i in 0..n_particles as usize {
859                frame.positions[i] = [i as f32, step as f32, 0.0];
860            }
861            writer.write_frame(&frame);
862        }
863        assert_eq!(writer.frame_count(), 3);
864        let data = writer.finalize();
865
866        let mut reader = BinaryTrajectoryReader::new(data).unwrap();
867        assert_eq!(reader.n_frames(), 3);
868        for step in 0u64..3 {
869            let frame = reader.read_frame().unwrap();
870            assert_eq!(frame.step, step);
871        }
872    }
873
874    #[test]
875    fn test_trajectory_reader_invalid_magic() {
876        let mut buf = vec![0u8; BinaryHeader::SIZE + 100];
877        buf[0] = b'X';
878        assert!(BinaryTrajectoryReader::new(buf).is_none());
879    }
880
881    #[test]
882    fn test_trajectory_seek_frame() {
883        let n = 2u64;
884        let mut writer = BinaryTrajectoryWriter::new(n, 0.01);
885        for step in 0u64..5 {
886            let frame = ParticleFrame::new(step, step as f64 * 0.01, n as usize);
887            writer.write_frame(&frame);
888        }
889        let data = writer.finalize();
890        let mut reader = BinaryTrajectoryReader::new(data).unwrap();
891        reader.seek_frame(3);
892        let frame = reader.read_frame().unwrap();
893        assert_eq!(frame.step, 3);
894    }
895
896    #[test]
897    fn test_trajectory_read_past_end() {
898        let n = 2u64;
899        let mut writer = BinaryTrajectoryWriter::new(n, 0.01);
900        writer.write_frame(&ParticleFrame::new(0, 0.0, n as usize));
901        let data = writer.finalize();
902        let mut reader = BinaryTrajectoryReader::new(data).unwrap();
903        assert!(reader.read_frame().is_some());
904        assert!(reader.read_frame().is_none());
905    }
906
907    #[test]
908    fn test_trajectory_frame_count() {
909        let mut writer = BinaryTrajectoryWriter::new(1, 0.01);
910        assert_eq!(writer.frame_count(), 0);
911        writer.write_frame(&ParticleFrame::new(0, 0.0, 1));
912        assert_eq!(writer.frame_count(), 1);
913    }
914
915    // ── DcdWriter/Reader tests ────────────────────────────────────────────
916
917    #[test]
918    fn test_dcd_write_read_roundtrip() {
919        let n_atoms = 5;
920        let mut writer = DcdWriter::new(n_atoms, 0.002);
921        let x: Vec<f32> = (0..n_atoms).map(|i| i as f32).collect();
922        let y: Vec<f32> = (0..n_atoms).map(|i| i as f32 * 2.0).collect();
923        let z: Vec<f32> = (0..n_atoms).map(|i| i as f32 * 3.0).collect();
924        writer.write_frame(&x, &y, &z);
925        let buf = writer.finalize();
926
927        let reader = DcdReader::parse_header(buf).unwrap();
928        assert_eq!(reader.n_atoms, n_atoms);
929        assert_eq!(reader.n_frames, 1);
930        let (rx, ry, rz) = reader.read_frame(0).unwrap();
931        assert!((rx[2] - 2.0).abs() < 1e-5);
932        assert!((ry[2] - 4.0).abs() < 1e-5);
933        assert!((rz[2] - 6.0).abs() < 1e-5);
934    }
935
936    #[test]
937    fn test_dcd_multiple_frames() {
938        let n_atoms = 3;
939        let mut writer = DcdWriter::new(n_atoms, 0.001);
940        for frame_i in 0..4 {
941            let x: Vec<f32> = vec![frame_i as f32; n_atoms];
942            let y: Vec<f32> = vec![0.0; n_atoms];
943            let z: Vec<f32> = vec![0.0; n_atoms];
944            writer.write_frame(&x, &y, &z);
945        }
946        let buf = writer.finalize();
947        let reader = DcdReader::parse_header(buf).unwrap();
948        assert_eq!(reader.n_frames, 4);
949        let (rx, _, _) = reader.read_frame(3).unwrap();
950        assert!((rx[0] - 3.0).abs() < 1e-5);
951    }
952
953    #[test]
954    fn test_dcd_out_of_range_frame() {
955        let mut writer = DcdWriter::new(2, 0.001);
956        writer.write_frame(&[1.0, 2.0], &[0.0, 0.0], &[0.0, 0.0]);
957        let buf = writer.finalize();
958        let reader = DcdReader::parse_header(buf).unwrap();
959        assert!(reader.read_frame(99).is_none());
960    }
961
962    #[test]
963    fn test_dcd_parse_header_too_short() {
964        assert!(DcdReader::parse_header(vec![0u8; 3]).is_none());
965    }
966
967    #[test]
968    fn test_dcd_zero_frames() {
969        let writer = DcdWriter::new(4, 0.001);
970        let buf = writer.finalize();
971        let reader = DcdReader::parse_header(buf).unwrap();
972        assert_eq!(reader.n_frames, 0);
973        assert!(reader.read_frame(0).is_none());
974    }
975
976    // ── XtcEncoder tests ─────────────────────────────────────────────────
977
978    #[test]
979    fn test_xtc_compress_decompress_roundtrip() {
980        let positions: Vec<[f32; 3]> =
981            vec![[1.0, 2.0, 3.0], [4.5, -1.5, 0.0], [-3.0, 0.001, 100.0]];
982        let buf = XtcEncoder::compress_frame(&positions);
983        let decoded = XtcEncoder::decompress_frame(&buf, 3).unwrap();
984        for (orig, dec) in positions.iter().zip(decoded.iter()) {
985            for i in 0..3 {
986                assert!(
987                    (orig[i] - dec[i]).abs() < 0.002,
988                    "mismatch at component {i}"
989                );
990            }
991        }
992    }
993
994    #[test]
995    fn test_xtc_empty_frame() {
996        let buf = XtcEncoder::compress_frame(&[]);
997        let decoded = XtcEncoder::decompress_frame(&buf, 0).unwrap();
998        assert!(decoded.is_empty());
999    }
1000
1001    #[test]
1002    fn test_xtc_decompress_too_short() {
1003        assert!(XtcEncoder::decompress_frame(&[0u8; 5], 3).is_none());
1004    }
1005
1006    #[test]
1007    fn test_xtc_single_particle() {
1008        let pos = vec![[0.123f32, -0.456, 7.89]];
1009        let buf = XtcEncoder::compress_frame(&pos);
1010        let dec = XtcEncoder::decompress_frame(&buf, 1).unwrap();
1011        assert!((dec[0][0] - 0.123).abs() < 0.002);
1012        assert!((dec[0][2] - 7.89).abs() < 0.002);
1013    }
1014
1015    #[test]
1016    fn test_xtc_compressed_size() {
1017        let n = 10;
1018        let positions = vec![[0.0f32; 3]; n];
1019        let buf = XtcEncoder::compress_frame(&positions);
1020        // 4 (n) + 4 (precision) + n*12 = 8 + 120 = 128
1021        assert_eq!(buf.len(), 8 + n * 12);
1022    }
1023
1024    // ── EnergyLogWriter tests ─────────────────────────────────────────────
1025
1026    #[test]
1027    fn test_energy_log_empty() {
1028        let log = EnergyLogWriter::new();
1029        assert!(log.is_empty());
1030        let csv = log.to_csv_string();
1031        assert!(csv.starts_with("step,ke,pe,total_energy"));
1032    }
1033
1034    #[test]
1035    fn test_energy_log_write_step() {
1036        let mut log = EnergyLogWriter::new();
1037        log.write_step(0, 10.0, -20.0, 300.0, 101325.0);
1038        assert_eq!(log.len(), 1);
1039    }
1040
1041    #[test]
1042    fn test_energy_log_csv_total_energy() {
1043        let mut log = EnergyLogWriter::new();
1044        log.write_step(1, 5.0, -3.0, 200.0, 1.0);
1045        let csv = log.to_csv_string();
1046        assert!(csv.contains("2.000000"), "expected total energy 2.0 in csv");
1047    }
1048
1049    #[test]
1050    fn test_energy_log_multiple_steps() {
1051        let mut log = EnergyLogWriter::new();
1052        for i in 0..10u64 {
1053            log.write_step(i, i as f64, -(i as f64), 300.0, 1.0);
1054        }
1055        assert_eq!(log.len(), 10);
1056        let csv = log.to_csv_string();
1057        let lines: Vec<&str> = csv.lines().collect();
1058        assert_eq!(lines.len(), 11); // header + 10 data
1059    }
1060
1061    #[test]
1062    fn test_energy_log_not_empty_after_write() {
1063        let mut log = EnergyLogWriter::new();
1064        log.write_step(0, 1.0, 2.0, 3.0, 4.0);
1065        assert!(!log.is_empty());
1066    }
1067
1068    // ── CheckpointWriter tests ────────────────────────────────────────────
1069
1070    #[test]
1071    fn test_checkpoint_roundtrip() {
1072        let positions = vec![[1.0f64, 2.0, 3.0], [4.0, 5.0, 6.0]];
1073        let velocities = vec![[0.1f64, 0.2, 0.3], [0.4, 0.5, 0.6]];
1074        let step = 42u64;
1075        let time = 0.042;
1076        let buf = CheckpointWriter::save_state(&positions, &velocities, step, time);
1077        let (pos2, vel2, step2, time2) = CheckpointWriter::load_state(&buf);
1078        assert_eq!(step2, step);
1079        assert!((time2 - time).abs() < 1e-15);
1080        assert!((pos2[0][0] - 1.0).abs() < 1e-15);
1081        assert!((vel2[1][2] - 0.6).abs() < 1e-15);
1082    }
1083
1084    #[test]
1085    fn test_checkpoint_zero_particles() {
1086        let buf = CheckpointWriter::save_state(&[], &[], 0, 0.0);
1087        let (pos, vel, step, time) = CheckpointWriter::load_state(&buf);
1088        assert!(pos.is_empty());
1089        assert!(vel.is_empty());
1090        assert_eq!(step, 0);
1091        assert!((time).abs() < 1e-15);
1092    }
1093
1094    #[test]
1095    fn test_checkpoint_magic_bytes() {
1096        let buf = CheckpointWriter::save_state(&[], &[], 0, 0.0);
1097        assert_eq!(&buf[0..4], b"OXCK");
1098    }
1099
1100    #[test]
1101    fn test_checkpoint_step_and_time() {
1102        let pos = vec![[0.0f64; 3]];
1103        let vel = vec![[0.0f64; 3]];
1104        let buf = CheckpointWriter::save_state(&pos, &vel, 9999, 99.99);
1105        let (_, _, s, t) = CheckpointWriter::load_state(&buf);
1106        assert_eq!(s, 9999);
1107        assert!((t - 99.99).abs() < 1e-10);
1108    }
1109
1110    #[test]
1111    fn test_checkpoint_positions_preserved() {
1112        let pos: Vec<[f64; 3]> = (0..5).map(|i| [i as f64; 3]).collect();
1113        let vel = vec![[0.0f64; 3]; 5];
1114        let buf = CheckpointWriter::save_state(&pos, &vel, 0, 0.0);
1115        let (pos2, _, _, _) = CheckpointWriter::load_state(&buf);
1116        for i in 0..5 {
1117            assert!((pos2[i][0] - i as f64).abs() < 1e-15);
1118        }
1119    }
1120
1121    #[test]
1122    fn test_checkpoint_velocities_preserved() {
1123        let pos = vec![[0.0f64; 3]; 3];
1124        let vel: Vec<[f64; 3]> = (0..3).map(|i| [i as f64 * 0.5; 3]).collect();
1125        let buf = CheckpointWriter::save_state(&pos, &vel, 0, 0.0);
1126        let (_, vel2, _, _) = CheckpointWriter::load_state(&buf);
1127        assert!((vel2[2][1] - 1.0).abs() < 1e-15);
1128    }
1129
1130    // ── Integration / edge-case tests ─────────────────────────────────────
1131
1132    #[test]
1133    fn test_full_pipeline_write_read() {
1134        // Write a trajectory with 2 particles and 5 frames, then read all back
1135        let n = 2u64;
1136        let dt = 0.005;
1137        let mut writer = BinaryTrajectoryWriter::new(n, dt);
1138        for step in 0u64..5 {
1139            let mut frame = ParticleFrame::new(step, step as f64 * dt, n as usize);
1140            frame.positions[0] = [step as f32, 0.0, 0.0];
1141            frame.positions[1] = [0.0, step as f32, 0.0];
1142            frame.masses = vec![1.0, 2.0];
1143            writer.write_frame(&frame);
1144        }
1145        let data = writer.finalize();
1146        let mut reader = BinaryTrajectoryReader::new(data).unwrap();
1147        assert_eq!(reader.n_frames(), 5);
1148        for step in 0u64..5 {
1149            let frame = reader.read_frame().unwrap();
1150            assert_eq!(frame.step, step);
1151            assert!((frame.positions[0][0] - step as f32).abs() < 1e-6);
1152            assert!((frame.masses[1] - 2.0).abs() < 1e-6);
1153        }
1154        assert!(reader.read_frame().is_none());
1155    }
1156
1157    #[test]
1158    fn test_energy_log_csv_format() {
1159        let mut log = EnergyLogWriter::new();
1160        log.write_step(0, 1.0, -1.0, 300.0, 1.0);
1161        log.write_step(1, 2.0, -2.0, 310.0, 1.1);
1162        let csv = log.to_csv_string();
1163        let mut lines = csv.lines();
1164        let header = lines.next().unwrap();
1165        assert!(header.contains("step"));
1166        assert!(header.contains("ke"));
1167        assert!(header.contains("pe"));
1168        let first_data = lines.next().unwrap();
1169        assert!(first_data.starts_with("0,"));
1170    }
1171
1172    #[test]
1173    fn test_dcd_n_atoms_in_reader() {
1174        let n_atoms = 7;
1175        let writer = DcdWriter::new(n_atoms, 0.002);
1176        let buf = writer.finalize();
1177        let reader = DcdReader::parse_header(buf).unwrap();
1178        assert_eq!(reader.n_atoms, n_atoms);
1179    }
1180
1181    #[test]
1182    fn test_xtc_precision_loss_is_small() {
1183        let pos = vec![[3.15625f32, 2.71875, -1.40625]];
1184        let buf = XtcEncoder::compress_frame(&pos);
1185        let dec = XtcEncoder::decompress_frame(&buf, 1).unwrap();
1186        for i in 0..3 {
1187            assert!(
1188                (pos[0][i] - dec[0][i]).abs() < 0.002,
1189                "precision too large for component {i}"
1190            );
1191        }
1192    }
1193
1194    #[test]
1195    fn test_binary_header_dt_preserved() {
1196        let h = BinaryHeader::new(10, 10, std::f64::consts::PI);
1197        let mut buf = Vec::new();
1198        h.write(&mut buf);
1199        let h2 = BinaryHeader::read(&buf).unwrap();
1200        assert!((h2.dt - std::f64::consts::PI).abs() < 1e-14);
1201    }
1202
1203    #[test]
1204    fn test_particle_frame_step_preserved() {
1205        let f = ParticleFrame::new(12345, 1.23, 1);
1206        let mut buf = Vec::new();
1207        f.serialize(&mut buf);
1208        let f2 = ParticleFrame::deserialize(&buf, 1).unwrap();
1209        assert_eq!(f2.step, 12345);
1210    }
1211}