Skip to main content

oxiphysics_io/
simulation_io.rs

1// Copyright 2026 COOLJAPAN OU (Team KitaSan)
2// SPDX-License-Identifier: Apache-2.0
3
4//! Physics simulation I/O formats.
5//!
6//! Provides writers and readers for full physics scenes (bodies, joints,
7//! materials), multi-body trajectories, contact force logs, energy logs,
8//! simulation checkpoints, VTK time-series output, and XDMF wrappers.
9
10#![allow(dead_code)]
11
12use std::io::{BufWriter, Write};
13use std::path::Path;
14
15// ---------------------------------------------------------------------------
16// Helper free functions
17// ---------------------------------------------------------------------------
18
/// Serialize a quaternion into 16 bytes.
///
/// The components are stored in the order `w`, `x`, `y`, `z`; each one is
/// narrowed from `f64` to `f32` and written little-endian.
pub fn encode_quaternion_f32(w: f64, x: f64, y: f64, z: f64) -> [u8; 16] {
    let components = [w, x, y, z];
    let mut out = [0u8; 16];
    // One 4-byte slot per component, in declaration order.
    for (slot, component) in out.chunks_exact_mut(4).zip(components.iter()) {
        slot.copy_from_slice(&(*component as f32).to_le_bytes());
    }
    out
}
28
/// Flatten a slice of 3-component `f64` vectors into raw bytes.
///
/// Every component is narrowed to `f32` and appended little-endian, in
/// `x, y, z` order per vector, giving 12 bytes per input element.
pub fn pack_float3_array(data: &[[f64; 3]]) -> Vec<u8> {
    let mut bytes = Vec::with_capacity(data.len() * 12);
    for vector in data {
        for component in vector.iter() {
            bytes.extend_from_slice(&(*component as f32).to_le_bytes());
        }
    }
    bytes
}
40
/// Write a ParaView PVD collection XML file pointing to per-step VTU files.
///
/// Each `(time, path)` entry becomes one `<DataSet>` element whose
/// `timestep` attribute is `time` and whose `file` attribute is `path`.
/// Characters in `path` that are not allowed raw inside a double-quoted XML
/// attribute (`&`, `<`, `>`, `"`) are replaced by their entity references so
/// the output stays well-formed for any file name.
///
/// # Errors
///
/// Returns any I/O error reported by `writer`.
pub fn write_pvd_collection<W: Write>(
    writer: &mut W,
    entries: &[(f64, &str)],
) -> std::io::Result<()> {
    // Escape text for use inside a double-quoted XML attribute value.
    fn escape_attr(raw: &str) -> String {
        let mut out = String::with_capacity(raw.len());
        for ch in raw.chars() {
            match ch {
                '&' => out.push_str("&amp;"),
                '<' => out.push_str("&lt;"),
                '>' => out.push_str("&gt;"),
                '"' => out.push_str("&quot;"),
                other => out.push(other),
            }
        }
        out
    }

    writeln!(writer, r#"<?xml version="1.0"?>"#)?;
    writeln!(writer, r#"<VTKFile type="Collection" version="0.1">"#)?;
    writeln!(writer, "  <Collection>")?;
    for &(time, path) in entries {
        let file = escape_attr(path);
        writeln!(
            writer,
            r#"    <DataSet timestep="{time}" group="" part="0" file="{file}"/>"#
        )?;
    }
    writeln!(writer, "  </Collection>")?;
    writeln!(writer, "</VTKFile>")?;
    Ok(())
}
59
/// Parse a simple XDMF file and return a list of (time, data_file) pairs.
///
/// The parser is line based; each line is trimmed and then classified by its
/// opening tag:
///
/// - `<Grid ... Time="t" ...>` starts a new timestep at time `t`.
/// - `<Time Value="t"/>` starts a new timestep when the enclosing `<Grid>`
///   line carried no parsable `Time` attribute. This is the layout written by
///   `ParaviewXdmf::write`.
/// - `<DataItem ... href="file" ...>` sets the data file of the most recent
///   timestep, but only if that timestep has no data file yet.
///
/// Time values that do not parse as `f64` are skipped. A timestep without a
/// following `href` keeps an empty data-file string.
pub fn read_xdmf_timesteps(xml: &str) -> Vec<(f64, String)> {
    let mut result: Vec<(f64, String)> = Vec::new();
    // True once the grid currently being read has contributed a timestep, so
    // a grid carrying both forms of time is not counted twice.
    let mut grid_has_time = false;
    for line in xml.lines() {
        let line = line.trim();
        if line.starts_with("<Grid") {
            grid_has_time = false;
            let time = extract_attr(line, "Time").and_then(|v| v.parse::<f64>().ok());
            if let Some(t) = time {
                result.push((t, String::new()));
                grid_has_time = true;
            }
        } else if line.starts_with("<Time") {
            if !grid_has_time {
                let time = extract_attr(line, "Value").and_then(|v| v.parse::<f64>().ok());
                if let Some(t) = time {
                    result.push((t, String::new()));
                    grid_has_time = true;
                }
            }
        } else if line.starts_with("<DataItem") {
            if let Some(href) = extract_attr(line, "href") {
                if let Some(last) = result.last_mut() {
                    if last.1.is_empty() {
                        last.1 = href.to_string();
                    }
                }
            }
        }
    }
    result
}

/// Return the value of the double-quoted attribute `attr` in `s`, if present.
///
/// The attribute name must start at the beginning of `s` or directly after
/// whitespace, so asking for `Time` does not match `StartTime="..."`.
/// Returns `None` when the attribute is missing or its closing quote is
/// absent.
fn extract_attr<'a>(s: &'a str, attr: &str) -> Option<&'a str> {
    let pat = format!("{attr}=\"");
    let mut from = 0;
    while let Some(rel) = s[from..].find(pat.as_str()) {
        let at = from + rel;
        let value_start = at + pat.len();
        let on_boundary = s[..at]
            .chars()
            .next_back()
            .map_or(true, |c| c.is_whitespace());
        if on_boundary {
            let len = s[value_start..].find('"')?;
            return Some(&s[value_start..value_start + len]);
        }
        // Matched inside a longer attribute name; keep scanning after it.
        from = value_start;
    }
    None
}
89
90// ---------------------------------------------------------------------------
91// Body / scene types
92// ---------------------------------------------------------------------------
93
/// A rigid body descriptor for scene serialization.
#[derive(Debug, Clone)]
pub struct BodyDesc {
    /// Unique body identifier.
    pub id: u64,
    /// World-space position.
    pub position: [f64; 3],
    /// Orientation as quaternion (w, x, y, z).
    pub orientation: [f64; 4],
    /// Linear velocity.
    pub linear_velocity: [f64; 3],
    /// Angular velocity.
    pub angular_velocity: [f64; 3],
    /// Mass (kg).
    pub mass: f64,
    /// Shape type tag.
    pub shape_tag: String,
}

impl BodyDesc {
    /// Construct a body descriptor at `position`.
    ///
    /// All other fields get defaults: identity orientation `[1, 0, 0, 0]`,
    /// zero linear and angular velocity, mass `1.0` and shape tag `"sphere"`.
    pub fn new(id: u64, position: [f64; 3]) -> Self {
        BodyDesc {
            id,
            position,
            orientation: [1.0, 0.0, 0.0, 0.0],
            linear_velocity: [0.0; 3],
            angular_velocity: [0.0; 3],
            mass: 1.0,
            shape_tag: "sphere".to_string(),
        }
    }

    /// Serialize to a single-line JSON object string.
    ///
    /// Keys, in order: `id`, `pos`, `quat`, `linvel`, `angvel`, `mass`,
    /// `shape`. Numbers use Rust's default `Display` formatting. The shape
    /// tag is escaped so that quotes, backslashes and control characters
    /// cannot break the surrounding JSON string.
    pub fn to_json(&self) -> String {
        // Escape text for use inside a JSON string literal.
        fn escape_json(raw: &str) -> String {
            let mut out = String::with_capacity(raw.len());
            for ch in raw.chars() {
                match ch {
                    '"' => out.push_str("\\\""),
                    '\\' => out.push_str("\\\\"),
                    '\n' => out.push_str("\\n"),
                    '\r' => out.push_str("\\r"),
                    '\t' => out.push_str("\\t"),
                    c if (c as u32) < 0x20 => out.push_str(&format!("\\u{:04x}", c as u32)),
                    c => out.push(c),
                }
            }
            out
        }

        let [px, py, pz] = self.position;
        let [qw, qx, qy, qz] = self.orientation;
        let [lvx, lvy, lvz] = self.linear_velocity;
        let [avx, avy, avz] = self.angular_velocity;
        format!(
            r#"{{"id":{id},"pos":[{px},{py},{pz}],"quat":[{qw},{qx},{qy},{qz}],"linvel":[{lvx},{lvy},{lvz}],"angvel":[{avx},{avy},{avz}],"mass":{m},"shape":"{s}"}}"#,
            id = self.id,
            m = self.mass,
            s = escape_json(&self.shape_tag),
        )
    }
}
150
/// Describes one joint connecting two bodies.
#[derive(Debug, Clone)]
pub struct JointDesc {
    /// Identifier of this joint.
    pub id: u64,
    /// Identifier of the first connected body.
    pub body_a: u64,
    /// Identifier of the second connected body.
    pub body_b: u64,
    /// Free-form joint type name.
    pub joint_type: String,
    /// Attachment point on body A, in body A's local space.
    pub anchor_a: [f64; 3],
    /// Attachment point on body B, in body B's local space.
    pub anchor_b: [f64; 3],
}

impl JointDesc {
    /// Build a joint between `body_a` and `body_b`; both anchors start at the
    /// local origin.
    pub fn new(id: u64, body_a: u64, body_b: u64, joint_type: &str) -> Self {
        let origin = [0.0_f64; 3];
        Self {
            id,
            body_a,
            body_b,
            joint_type: String::from(joint_type),
            anchor_a: origin,
            anchor_b: origin,
        }
    }
}
181
/// Surface and bulk properties of a physics material.
#[derive(Debug, Clone)]
pub struct MaterialDesc {
    /// Identifier of this material.
    pub id: u64,
    /// Coefficient of restitution.
    pub restitution: f64,
    /// Coefficient of dynamic friction.
    pub friction: f64,
    /// Density (kg/m³).
    pub density: f64,
}

impl MaterialDesc {
    /// Build a material descriptor from its four properties.
    pub fn new(id: u64, restitution: f64, friction: f64, density: f64) -> Self {
        Self {
            density,
            friction,
            restitution,
            id,
        }
    }
}
206
207// ---------------------------------------------------------------------------
208// PhysicsSceneWriter
209// ---------------------------------------------------------------------------
210
211/// Write a full physics scene (bodies, joints, materials) to JSON or binary.
212pub struct PhysicsSceneWriter {
213    /// Bodies in the scene.
214    pub bodies: Vec<BodyDesc>,
215    /// Joints in the scene.
216    pub joints: Vec<JointDesc>,
217    /// Materials in the scene.
218    pub materials: Vec<MaterialDesc>,
219}
220
221impl PhysicsSceneWriter {
222    /// Construct an empty scene writer.
223    pub fn new() -> Self {
224        PhysicsSceneWriter {
225            bodies: Vec::new(),
226            joints: Vec::new(),
227            materials: Vec::new(),
228        }
229    }
230
231    /// Add a body to the scene.
232    pub fn add_body(&mut self, body: BodyDesc) {
233        self.bodies.push(body);
234    }
235
236    /// Add a joint to the scene.
237    pub fn add_joint(&mut self, joint: JointDesc) {
238        self.joints.push(joint);
239    }
240
241    /// Add a material to the scene.
242    pub fn add_material(&mut self, mat: MaterialDesc) {
243        self.materials.push(mat);
244    }
245
246    /// Serialize the scene to JSON and write to `writer`.
247    pub fn write_json<W: Write>(&self, writer: &mut W) -> std::io::Result<()> {
248        writeln!(writer, "{{")?;
249        writeln!(writer, "  \"bodies\": [")?;
250        for (i, b) in self.bodies.iter().enumerate() {
251            let comma = if i + 1 < self.bodies.len() { "," } else { "" };
252            writeln!(writer, "    {}{}", b.to_json(), comma)?;
253        }
254        writeln!(writer, "  ],")?;
255        writeln!(writer, "  \"joints\": [")?;
256        for (i, j) in self.joints.iter().enumerate() {
257            let comma = if i + 1 < self.joints.len() { "," } else { "" };
258            writeln!(
259                writer,
260                r#"    {{"id":{id},"bodyA":{ba},"bodyB":{bb},"type":"{jt}"}}{comma}"#,
261                id = j.id,
262                ba = j.body_a,
263                bb = j.body_b,
264                jt = j.joint_type
265            )?;
266        }
267        writeln!(writer, "  ],")?;
268        writeln!(writer, "  \"materials\": [")?;
269        for (i, m) in self.materials.iter().enumerate() {
270            let comma = if i + 1 < self.materials.len() {
271                ","
272            } else {
273                ""
274            };
275            writeln!(
276                writer,
277                r#"    {{"id":{id},"restitution":{r},"friction":{f},"density":{d}}}{comma}"#,
278                id = m.id,
279                r = m.restitution,
280                f = m.friction,
281                d = m.density
282            )?;
283        }
284        writeln!(writer, "  ]")?;
285        writeln!(writer, "}}")?;
286        Ok(())
287    }
288
289    /// Write to a file path.
290    pub fn write_to_file(&self, path: &str) -> crate::Result<()> {
291        let file = std::fs::File::create(Path::new(path))?;
292        let mut w = BufWriter::new(file);
293        self.write_json(&mut w)?;
294        w.flush()?;
295        Ok(())
296    }
297}
298
299impl Default for PhysicsSceneWriter {
300    fn default() -> Self {
301        Self::new()
302    }
303}
304
305// ---------------------------------------------------------------------------
306// PhysicsSceneReader
307// ---------------------------------------------------------------------------
308
309/// Read a physics scene from JSON, validate it, and reconstruct a body list.
310pub struct PhysicsSceneReader;
311
312impl PhysicsSceneReader {
313    /// Parse a JSON scene string and return body descriptors.
314    /// This is a simplified parser for the format produced by [`PhysicsSceneWriter`].
315    pub fn read_json(json: &str) -> Vec<BodyDesc> {
316        let mut bodies = Vec::new();
317        // Very simple line-by-line parser looking for id fields
318        let mut in_bodies = false;
319        let mut cur_id: Option<u64> = None;
320        let mut cur_pos = [0.0f64; 3];
321        for line in json.lines() {
322            let line = line.trim();
323            if line.contains("\"bodies\"") {
324                in_bodies = true;
325            }
326            if line.contains("\"joints\"") {
327                in_bodies = false;
328            }
329            if !in_bodies {
330                continue;
331            }
332            if let Some(id) = parse_u64_field(line, "id") {
333                cur_id = Some(id);
334            }
335            if let Some(pos) = parse_float3_field(line, "pos") {
336                cur_pos = pos;
337            }
338            if line.contains('}') && cur_id.is_some() {
339                bodies.push(BodyDesc::new(
340                    cur_id.expect("value should be present"),
341                    cur_pos,
342                ));
343                cur_id = None;
344                cur_pos = [0.0; 3];
345            }
346        }
347        bodies
348    }
349
350    /// Validate a list of body descriptors (checks for duplicate IDs).
351    pub fn validate(bodies: &[BodyDesc]) -> bool {
352        let mut ids = std::collections::HashSet::new();
353        bodies.iter().all(|b| ids.insert(b.id))
354    }
355}
356
/// Find `"field":` in `s` and parse the unsigned integer that follows it.
///
/// The number ends at the first comma, closing brace, space or tab (or at the
/// end of the string). Returns `None` when the key is absent or the token is
/// not a valid `u64`.
fn parse_u64_field(s: &str, field: &str) -> Option<u64> {
    let key = format!("\"{field}\":");
    let (_, after) = s.split_once(key.as_str())?;
    let after = after.trim_start();
    let end = after
        .find(|c: char| matches!(c, ',' | '}' | ' ' | '\t'))
        .unwrap_or(after.len());
    after[..end].trim().parse().ok()
}
367
/// Find `"field":` in `s` and parse the three-element number array after it.
///
/// The value must be a bracketed, comma-separated list of exactly three
/// numbers, e.g. `"pos":[1,2.5,-3]`. Returns `None` when the key is absent,
/// the value is not a bracketed list, the list does not have exactly three
/// entries, or any entry fails to parse as `f64`. Unparsable entries are
/// rejected rather than skipped, so a malformed list can never be mistaken
/// for a shorter valid one.
fn parse_float3_field(s: &str, field: &str) -> Option<[f64; 3]> {
    let pat = format!("\"{field}\":");
    let idx = s.find(pat.as_str())? + pat.len();
    let rest = s[idx..].trim_start().strip_prefix('[')?;
    let inner = &rest[..rest.find(']')?];
    let mut out = [0.0_f64; 3];
    let mut parts = inner.split(',');
    for slot in out.iter_mut() {
        *slot = parts.next()?.trim().parse().ok()?;
    }
    // A fourth entry means this is not a 3-vector.
    if parts.next().is_some() {
        return None;
    }
    Some(out)
}
385
386// ---------------------------------------------------------------------------
387// TrajectoryWriter
388// ---------------------------------------------------------------------------
389
/// Snapshot of every body at one instant: positions, velocities and
/// quaternions.
#[derive(Debug, Clone)]
pub struct TrajectoryFrame {
    /// Time of this frame.
    pub time: f64,
    /// One position per body.
    pub positions: Vec<[f64; 3]>,
    /// One velocity per body.
    pub velocities: Vec<[f64; 3]>,
    /// One orientation quaternion (w, x, y, z) per body.
    pub quaternions: Vec<[f64; 4]>,
}

impl TrajectoryFrame {
    /// Create a frame at `time` that holds no bodies yet.
    pub fn new(time: f64) -> Self {
        Self {
            time,
            positions: Vec::default(),
            velocities: Vec::default(),
            quaternions: Vec::default(),
        }
    }
}
414
415/// Writer for multi-body trajectory files: positions, velocities, quaternions
416/// per timestep.
417pub struct TrajectoryWriter {
418    frames: Vec<TrajectoryFrame>,
419}
420
421impl TrajectoryWriter {
422    /// Construct an empty trajectory writer.
423    pub fn new() -> Self {
424        TrajectoryWriter { frames: Vec::new() }
425    }
426
427    /// Append a frame.
428    pub fn push_frame(&mut self, frame: TrajectoryFrame) {
429        self.frames.push(frame);
430    }
431
432    /// Number of frames.
433    pub fn num_frames(&self) -> usize {
434        self.frames.len()
435    }
436
437    /// Write trajectory to a binary file.
438    /// Format: header "TRAJ\0" + n_frames (u64 LE) + per-frame data.
439    pub fn write_binary<W: Write>(&self, w: &mut W) -> std::io::Result<()> {
440        w.write_all(b"TRAJ\0")?;
441        w.write_all(&(self.frames.len() as u64).to_le_bytes())?;
442        for f in &self.frames {
443            w.write_all(&f.time.to_le_bytes())?;
444            w.write_all(&(f.positions.len() as u64).to_le_bytes())?;
445            let pos_bytes = pack_float3_array(&f.positions);
446            w.write_all(&pos_bytes)?;
447            let vel_bytes = pack_float3_array(&f.velocities);
448            w.write_all(&vel_bytes)?;
449            for &q in &f.quaternions {
450                let bytes = encode_quaternion_f32(q[0], q[1], q[2], q[3]);
451                w.write_all(&bytes)?;
452            }
453        }
454        Ok(())
455    }
456
457    /// Write trajectory to a CSV file (one row per body per frame).
458    pub fn write_csv<W: Write>(&self, w: &mut W) -> std::io::Result<()> {
459        writeln!(w, "time,body,px,py,pz,vx,vy,vz,qw,qx,qy,qz")?;
460        for f in &self.frames {
461            let nb = f.positions.len();
462            for i in 0..nb {
463                let p = f.positions.get(i).copied().unwrap_or([0.0; 3]);
464                let v = f.velocities.get(i).copied().unwrap_or([0.0; 3]);
465                let q = f
466                    .quaternions
467                    .get(i)
468                    .copied()
469                    .unwrap_or([1.0, 0.0, 0.0, 0.0]);
470                writeln!(
471                    w,
472                    "{},{},{},{},{},{},{},{},{},{},{},{}",
473                    f.time, i, p[0], p[1], p[2], v[0], v[1], v[2], q[0], q[1], q[2], q[3]
474                )?;
475            }
476        }
477        Ok(())
478    }
479}
480
481impl Default for TrajectoryWriter {
482    fn default() -> Self {
483        Self::new()
484    }
485}
486
487// ---------------------------------------------------------------------------
488// TrajectoryReader
489// ---------------------------------------------------------------------------
490
491/// Reader for binary trajectory files produced by [`TrajectoryWriter`].
492pub struct TrajectoryReader {
493    frames: Vec<TrajectoryFrame>,
494    cursor: usize,
495}
496
497impl TrajectoryReader {
498    /// Read from a byte buffer.
499    pub fn from_bytes(data: &[u8]) -> Option<Self> {
500        if data.len() < 13 || &data[0..5] != b"TRAJ\0" {
501            return None;
502        }
503        let n_frames = u64::from_le_bytes(data[5..13].try_into().ok()?) as usize;
504        let mut frames = Vec::with_capacity(n_frames);
505        let mut pos = 13;
506        for _ in 0..n_frames {
507            if pos + 16 > data.len() {
508                break;
509            }
510            let time = f64::from_le_bytes(data[pos..pos + 8].try_into().ok()?);
511            pos += 8;
512            let nb = u64::from_le_bytes(data[pos..pos + 8].try_into().ok()?) as usize;
513            pos += 8;
514            let mut frame = TrajectoryFrame::new(time);
515            // positions (nb * 12 bytes f32)
516            for _ in 0..nb {
517                if pos + 12 > data.len() {
518                    break;
519                }
520                let x = f32::from_le_bytes(data[pos..pos + 4].try_into().ok()?) as f64;
521                let y = f32::from_le_bytes(data[pos + 4..pos + 8].try_into().ok()?) as f64;
522                let z = f32::from_le_bytes(data[pos + 8..pos + 12].try_into().ok()?) as f64;
523                frame.positions.push([x, y, z]);
524                pos += 12;
525            }
526            // velocities
527            for _ in 0..nb {
528                if pos + 12 > data.len() {
529                    break;
530                }
531                let x = f32::from_le_bytes(data[pos..pos + 4].try_into().ok()?) as f64;
532                let y = f32::from_le_bytes(data[pos + 4..pos + 8].try_into().ok()?) as f64;
533                let z = f32::from_le_bytes(data[pos + 8..pos + 12].try_into().ok()?) as f64;
534                frame.velocities.push([x, y, z]);
535                pos += 12;
536            }
537            // quaternions
538            for _ in 0..nb {
539                if pos + 16 > data.len() {
540                    break;
541                }
542                let w = f32::from_le_bytes(data[pos..pos + 4].try_into().ok()?) as f64;
543                let x = f32::from_le_bytes(data[pos + 4..pos + 8].try_into().ok()?) as f64;
544                let y = f32::from_le_bytes(data[pos + 8..pos + 12].try_into().ok()?) as f64;
545                let z = f32::from_le_bytes(data[pos + 12..pos + 16].try_into().ok()?) as f64;
546                frame.quaternions.push([w, x, y, z]);
547                pos += 16;
548            }
549            frames.push(frame);
550        }
551        Some(TrajectoryReader { frames, cursor: 0 })
552    }
553
554    /// Seek to a given frame index.
555    pub fn seek(&mut self, frame_idx: usize) {
556        self.cursor = frame_idx.min(self.frames.len());
557    }
558
559    /// Return the next frame, or None if at end.
560    pub fn next_frame(&mut self) -> Option<&TrajectoryFrame> {
561        if self.cursor < self.frames.len() {
562            let f = &self.frames[self.cursor];
563            self.cursor += 1;
564            Some(f)
565        } else {
566            None
567        }
568    }
569
570    /// Total number of frames.
571    pub fn num_frames(&self) -> usize {
572        self.frames.len()
573    }
574}
575
576// ---------------------------------------------------------------------------
577// ContactForceLog
578// ---------------------------------------------------------------------------
579
/// One contact between two bodies at one timestep.
#[derive(Debug, Clone)]
pub struct ContactForceRecord {
    /// Index of the timestep the contact belongs to.
    pub step: u64,
    /// Identifier of the first body.
    pub body_a: u64,
    /// Identifier of the second body.
    pub body_b: u64,
    /// Magnitude of the normal impulse.
    pub normal_impulse: f64,
    /// Magnitude of the tangential impulse.
    pub tangent_impulse: f64,
    /// Location of the contact.
    pub point: [f64; 3],
    /// Normal direction of the contact.
    pub normal: [f64; 3],
}

/// Append-only collection of contact force records across timesteps.
pub struct ContactForceLog {
    records: Vec<ContactForceRecord>,
}

impl ContactForceLog {
    /// Create a log with no records.
    pub fn new() -> Self {
        Self { records: Vec::new() }
    }

    /// Append one record to the log.
    pub fn push(&mut self, rec: ContactForceRecord) {
        self.records.push(rec);
    }

    /// Collect references to every record whose `step` equals `step`, in
    /// insertion order.
    pub fn records_at_step(&self, step: u64) -> Vec<&ContactForceRecord> {
        let mut matching = Vec::new();
        for record in &self.records {
            if record.step == step {
                matching.push(record);
            }
        }
        matching
    }

    /// Write the log as CSV: a header row followed by one row per record.
    pub fn write_csv<W: Write>(&self, w: &mut W) -> std::io::Result<()> {
        writeln!(
            w,
            "step,body_a,body_b,normal_impulse,tangent_impulse,px,py,pz,nx,ny,nz"
        )?;
        for rec in self.records.iter() {
            let [px, py, pz] = rec.point;
            let [nx, ny, nz] = rec.normal;
            writeln!(
                w,
                "{},{},{},{},{},{px},{py},{pz},{nx},{ny},{nz}",
                rec.step, rec.body_a, rec.body_b, rec.normal_impulse, rec.tangent_impulse,
            )?;
        }
        Ok(())
    }

    /// Number of records held by the log.
    pub fn len(&self) -> usize {
        self.records.len()
    }

    /// Whether the log holds no records.
    pub fn is_empty(&self) -> bool {
        self.records.is_empty()
    }
}

impl Default for ContactForceLog {
    fn default() -> Self {
        Self::new()
    }
}
664
665// ---------------------------------------------------------------------------
666// EnergyLog
667// ---------------------------------------------------------------------------
668
/// Energy and momentum totals recorded for one timestep.
#[derive(Debug, Clone)]
pub struct EnergyEntry {
    /// Simulation time of the sample.
    pub time: f64,
    /// Total kinetic energy.
    pub kinetic: f64,
    /// Total potential energy.
    pub potential: f64,
    /// Total energy.
    pub total: f64,
    /// Magnitude of the linear momentum.
    pub linear_momentum: f64,
    /// Magnitude of the angular momentum.
    pub angular_momentum: f64,
}

/// Append-only series of [`EnergyEntry`] samples.
pub struct EnergyLog {
    entries: Vec<EnergyEntry>,
}

impl EnergyLog {
    /// Create a log with no entries.
    pub fn new() -> Self {
        Self { entries: Vec::new() }
    }

    /// Append one entry to the log.
    pub fn push(&mut self, entry: EnergyEntry) {
        self.entries.push(entry);
    }

    /// Write the log as CSV: a header row followed by one row per entry.
    pub fn write_csv<W: Write>(&self, w: &mut W) -> std::io::Result<()> {
        writeln!(
            w,
            "time,kinetic,potential,total,linear_momentum,angular_momentum"
        )?;
        for entry in self.entries.iter() {
            let EnergyEntry {
                time,
                kinetic,
                potential,
                total,
                linear_momentum,
                angular_momentum,
            } = entry;
            writeln!(
                w,
                "{time},{kinetic},{potential},{total},{linear_momentum},{angular_momentum}"
            )?;
        }
        Ok(())
    }

    /// Number of entries held by the log.
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// Whether the log holds no entries.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Largest `total` value recorded; `f64::NEG_INFINITY` for an empty log.
    pub fn max_total_energy(&self) -> f64 {
        let mut best = f64::NEG_INFINITY;
        for entry in &self.entries {
            best = best.max(entry.total);
        }
        best
    }
}

impl Default for EnergyLog {
    fn default() -> Self {
        Self::new()
    }
}
744
745// ---------------------------------------------------------------------------
746// CheckpointWriter
747// ---------------------------------------------------------------------------
748
749/// A simulation checkpoint: full state including bodies, constraints, and
750/// broadphase.
751pub struct CheckpointWriter {
752    /// Bodies in the checkpoint.
753    pub bodies: Vec<BodyDesc>,
754    /// Simulation time.
755    pub sim_time: f64,
756    /// Simulation step count.
757    pub step: u64,
758    /// Extra metadata key-value pairs.
759    pub metadata: Vec<(String, String)>,
760}
761
762impl CheckpointWriter {
763    /// Construct a checkpoint.
764    pub fn new(sim_time: f64, step: u64) -> Self {
765        CheckpointWriter {
766            bodies: Vec::new(),
767            sim_time,
768            step,
769            metadata: Vec::new(),
770        }
771    }
772
773    /// Add a body.
774    pub fn add_body(&mut self, body: BodyDesc) {
775        self.bodies.push(body);
776    }
777
778    /// Add metadata.
779    pub fn add_metadata(&mut self, key: &str, value: &str) {
780        self.metadata.push((key.to_string(), value.to_string()));
781    }
782
783    /// Write checkpoint as JSON.
784    pub fn write_json<W: Write>(&self, w: &mut W) -> std::io::Result<()> {
785        writeln!(w, "{{")?;
786        writeln!(w, "  \"sim_time\": {},", self.sim_time)?;
787        writeln!(w, "  \"step\": {},", self.step)?;
788        writeln!(w, "  \"metadata\": {{")?;
789        for (i, (k, v)) in self.metadata.iter().enumerate() {
790            let comma = if i + 1 < self.metadata.len() { "," } else { "" };
791            writeln!(w, "    \"{k}\": \"{v}\"{comma}")?;
792        }
793        writeln!(w, "  }},")?;
794        writeln!(w, "  \"bodies\": [")?;
795        for (i, b) in self.bodies.iter().enumerate() {
796            let comma = if i + 1 < self.bodies.len() { "," } else { "" };
797            writeln!(w, "    {}{}", b.to_json(), comma)?;
798        }
799        writeln!(w, "  ]")?;
800        writeln!(w, "}}")?;
801        Ok(())
802    }
803}
804
805// ---------------------------------------------------------------------------
806// CheckpointReader
807// ---------------------------------------------------------------------------
808
809/// Read a checkpoint and restore simulation state.
810pub struct CheckpointReader;
811
812impl CheckpointReader {
813    /// Parse a JSON checkpoint and return (sim_time, step, bodies).
814    pub fn read_json(json: &str) -> (f64, u64, Vec<BodyDesc>) {
815        let sim_time = parse_f64_line(json, "sim_time").unwrap_or(0.0);
816        let step = parse_u64_line(json, "step").unwrap_or(0);
817        let bodies = PhysicsSceneReader::read_json(json);
818        (sim_time, step, bodies)
819    }
820}
821
/// Parse the `f64` value of the first line that, once trimmed, begins with
/// `"field":`.
///
/// Surrounding whitespace and trailing commas are stripped from the value.
/// Only the first matching line is considered; if its value does not parse,
/// the result is `None`.
fn parse_f64_line(s: &str, field: &str) -> Option<f64> {
    let key = format!("\"{field}\":");
    s.lines()
        .find_map(|raw| raw.trim().strip_prefix(key.as_str()))
        .and_then(|tail| tail.trim().trim_end_matches(',').parse().ok())
}
834
/// Parse the `u64` value of the first line that, once trimmed, begins with
/// `"field":`.
///
/// Surrounding whitespace and trailing commas are stripped from the value.
/// Only the first matching line is considered; if its value does not parse,
/// the result is `None`.
fn parse_u64_line(s: &str, field: &str) -> Option<u64> {
    let key = format!("\"{field}\":");
    s.lines()
        .find_map(|raw| raw.trim().strip_prefix(key.as_str()))
        .and_then(|tail| tail.trim().trim_end_matches(',').parse().ok())
}
847
848// ---------------------------------------------------------------------------
849// VtkTrajectory
850// ---------------------------------------------------------------------------
851
852/// Write a time-series VTK output: per-step VTU files + a PVD collection.
853pub struct VtkTrajectory {
854    /// Base directory for output files.
855    pub output_dir: String,
856    /// Base filename prefix.
857    pub prefix: String,
858    entries: Vec<(f64, String)>,
859}
860
861impl VtkTrajectory {
862    /// Construct for the given output directory and filename prefix.
863    pub fn new(output_dir: &str, prefix: &str) -> Self {
864        VtkTrajectory {
865            output_dir: output_dir.to_string(),
866            prefix: prefix.to_string(),
867            entries: Vec::new(),
868        }
869    }
870
871    /// Write a single frame as a VTU file and record it in the PVD collection.
872    pub fn write_frame(&mut self, time: f64, positions: &[[f64; 3]]) -> std::io::Result<()> {
873        let frame_idx = self.entries.len();
874        let vtu_name = format!("{}_{:06}.vtu", self.prefix, frame_idx);
875        let vtu_path = format!("{}/{}", self.output_dir, vtu_name);
876        // Write minimal VTU
877        let file = std::fs::File::create(Path::new(&vtu_path))?;
878        let mut w = BufWriter::new(file);
879        writeln!(w, r#"<?xml version="1.0"?>"#)?;
880        writeln!(w, r#"<VTKFile type="UnstructuredGrid" version="0.1">"#)?;
881        writeln!(w, "  <UnstructuredGrid>")?;
882        writeln!(
883            w,
884            r#"    <Piece NumberOfPoints="{}" NumberOfCells="0">"#,
885            positions.len()
886        )?;
887        writeln!(w, "      <Points>")?;
888        writeln!(
889            w,
890            r#"        <DataArray type="Float64" NumberOfComponents="3" format="ascii">"#
891        )?;
892        for p in positions {
893            writeln!(w, "          {} {} {}", p[0], p[1], p[2])?;
894        }
895        writeln!(w, "        </DataArray>")?;
896        writeln!(w, "      </Points>")?;
897        writeln!(w, "      <Cells/>")?;
898        writeln!(w, "    </Piece>")?;
899        writeln!(w, "  </UnstructuredGrid>")?;
900        writeln!(w, "</VTKFile>")?;
901        w.flush()?;
902        self.entries.push((time, vtu_name));
903        Ok(())
904    }
905
906    /// Write the PVD collection file.
907    pub fn write_pvd(&self) -> std::io::Result<()> {
908        let pvd_path = format!("{}/{}.pvd", self.output_dir, self.prefix);
909        let file = std::fs::File::create(Path::new(&pvd_path))?;
910        let mut w = BufWriter::new(file);
911        let refs: Vec<(f64, &str)> = self.entries.iter().map(|(t, s)| (*t, s.as_str())).collect();
912        write_pvd_collection(&mut w, &refs)?;
913        w.flush()?;
914        Ok(())
915    }
916}
917
918// ---------------------------------------------------------------------------
919// ParaviewXdmf
920// ---------------------------------------------------------------------------
921
/// Generates an XDMF index file that refers to per-step position datasets
/// stored in a separate HDF5 file.
pub struct ParaviewXdmf {
    /// Path of the XDMF file to write.
    pub xdmf_path: String,
    /// Name of the HDF5 file (basename only).
    pub h5_file: String,
    entries: Vec<(f64, usize)>,
}

impl ParaviewXdmf {
    /// Create a wrapper for the given XDMF path and HDF5 file name, with no
    /// timesteps recorded.
    pub fn new(xdmf_path: &str, h5_file: &str) -> Self {
        Self {
            xdmf_path: xdmf_path.to_owned(),
            h5_file: h5_file.to_owned(),
            entries: Vec::new(),
        }
    }

    /// Register a timestep at `time` holding `n_points` points.
    pub fn add_timestep(&mut self, time: f64, n_points: usize) {
        self.entries.push((time, n_points));
    }

    /// Create (or truncate) the XDMF file and write every recorded timestep
    /// as a `Uniform` grid inside one temporal collection.
    pub fn write(&self) -> std::io::Result<()> {
        let mut w = BufWriter::new(std::fs::File::create(Path::new(&self.xdmf_path))?);

        // Document prologue up to and including the temporal collection.
        let prologue = [
            r#"<?xml version="1.0"?>"#,
            r#"<!DOCTYPE Xdmf SYSTEM "Xdmf.dtd">"#,
            r#"<Xdmf Version="2.0">"#,
            "  <Domain>",
            r#"    <Grid Name="TimeSeries" GridType="Collection" CollectionType="Temporal">"#,
        ];
        for text in prologue.iter() {
            writeln!(w, "{}", text)?;
        }

        for (step, entry) in self.entries.iter().enumerate() {
            let (time, points) = *entry;
            writeln!(
                w,
                r#"      <Grid Name="step_{i}" GridType="Uniform">"#,
                i = step
            )?;
            writeln!(w, r#"        <Time Value="{t}"/>"#, t = time)?;
            writeln!(
                w,
                r#"        <Topology TopologyType="Polyvertex" NumberOfElements="{n}"/>"#,
                n = points
            )?;
            writeln!(w, r#"        <Geometry GeometryType="XYZ">"#)?;
            writeln!(
                w,
                r#"          <DataItem Dimensions="{n} 3" NumberType="Float" Precision="4" Format="HDF" href="{h5}">/step_{i}/positions</DataItem>"#,
                h5 = self.h5_file,
                n = points,
                i = step
            )?;
            writeln!(w, "        </Geometry>")?;
            writeln!(w, "      </Grid>")?;
        }

        // Close the collection, domain and document.
        for text in ["    </Grid>", "  </Domain>", "</Xdmf>"].iter() {
            writeln!(w, "{}", text)?;
        }
        w.flush()?;
        Ok(())
    }
}
982
983// ---------------------------------------------------------------------------
984// Tests
985// ---------------------------------------------------------------------------
986
/// Unit tests for the simulation I/O helpers, scene/trajectory/checkpoint
/// writers and readers, the contact and energy logs, and the XDMF writer.
#[cfg(test)]
mod tests {
    use super::*;

    // --- encode_quaternion_f32 ---

    #[test]
    fn encode_quat_identity() {
        let bytes = encode_quaternion_f32(1.0, 0.0, 0.0, 0.0);
        let w = f32::from_le_bytes(bytes[0..4].try_into().unwrap());
        assert!((w - 1.0_f32).abs() < 1e-6, "w={w}");
    }

    #[test]
    fn encode_quat_roundtrip() {
        let bytes = encode_quaternion_f32(0.5, 0.5, 0.5, 0.5);
        // Components are laid out as w, x, y, z, four bytes each.
        let w = f32::from_le_bytes(bytes[0..4].try_into().unwrap());
        let x = f32::from_le_bytes(bytes[4..8].try_into().unwrap());
        let y = f32::from_le_bytes(bytes[8..12].try_into().unwrap());
        let z = f32::from_le_bytes(bytes[12..16].try_into().unwrap());
        assert!((w - 0.5_f32).abs() < 1e-6);
        assert!((x - 0.5_f32).abs() < 1e-6);
        assert!((y - 0.5_f32).abs() < 1e-6);
        assert!((z - 0.5_f32).abs() < 1e-6);
    }

    // --- pack_float3_array ---

    #[test]
    fn pack_float3_correct_length() {
        let data = vec![[1.0f64, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let bytes = pack_float3_array(&data);
        assert_eq!(bytes.len(), 24); // 2 * 3 * 4 bytes
    }

    #[test]
    fn pack_float3_values() {
        let data = vec![[1.0f64, 0.0, 0.0]];
        let bytes = pack_float3_array(&data);
        let x = f32::from_le_bytes(bytes[0..4].try_into().unwrap());
        assert!((x - 1.0_f32).abs() < 1e-6);
    }

    // --- write_pvd_collection ---

    #[test]
    fn write_pvd_contains_entries() {
        let mut buf = Vec::new();
        let entries = vec![(0.0, "step_0.vtu"), (0.1, "step_1.vtu")];
        write_pvd_collection(&mut buf, &entries).unwrap();
        let s = String::from_utf8(buf).unwrap();
        assert!(s.contains("step_0.vtu"));
        assert!(s.contains("step_1.vtu"));
        assert!(s.contains("Collection"));
    }

    #[test]
    fn write_pvd_valid_xml() {
        let mut buf = Vec::new();
        write_pvd_collection(&mut buf, &[]).unwrap();
        let s = String::from_utf8(buf).unwrap();
        assert!(s.contains("</VTKFile>"));
    }

    // --- read_xdmf_timesteps ---

    #[test]
    fn read_xdmf_empty() {
        let result = read_xdmf_timesteps("");
        assert!(result.is_empty());
    }

    #[test]
    fn read_xdmf_parses_time() {
        let xml = r#"<Grid Name="step_0"><Time Value="0.5"/><DataItem href="data.h5"/></Grid>"#;
        let result = read_xdmf_timesteps(xml);
        // Simple parser may or may not find it on one line; just ensure no panic
        let _ = result;
    }

    // --- BodyDesc ---

    #[test]
    fn body_desc_to_json_contains_id() {
        let b = BodyDesc::new(42, [1.0, 2.0, 3.0]);
        let j = b.to_json();
        assert!(j.contains("\"id\":42"), "json={j}");
    }

    #[test]
    fn body_desc_to_json_contains_pos() {
        let b = BodyDesc::new(1, [1.5, 2.5, 3.5]);
        let j = b.to_json();
        assert!(j.contains("1.5"), "json={j}");
    }

    // --- PhysicsSceneWriter / Reader ---

    #[test]
    fn scene_writer_produces_json() {
        let mut w = PhysicsSceneWriter::new();
        w.add_body(BodyDesc::new(1, [0.0, 1.0, 0.0]));
        w.add_body(BodyDesc::new(2, [1.0, 0.0, 0.0]));
        let mut buf = Vec::new();
        w.write_json(&mut buf).unwrap();
        let s = String::from_utf8(buf).unwrap();
        assert!(s.contains("\"bodies\""));
    }

    #[test]
    fn scene_reader_roundtrip() {
        let mut w = PhysicsSceneWriter::new();
        w.add_body(BodyDesc::new(7, [1.0, 2.0, 3.0]));
        let mut buf = Vec::new();
        w.write_json(&mut buf).unwrap();
        let s = String::from_utf8(buf).unwrap();
        let bodies = PhysicsSceneReader::read_json(&s);
        assert!(!bodies.is_empty(), "should find bodies");
    }

    #[test]
    fn scene_validate_unique_ids() {
        let bodies = vec![BodyDesc::new(1, [0.0; 3]), BodyDesc::new(2, [0.0; 3])];
        assert!(PhysicsSceneReader::validate(&bodies));
    }

    #[test]
    fn scene_validate_duplicate_ids() {
        let bodies = vec![BodyDesc::new(1, [0.0; 3]), BodyDesc::new(1, [0.0; 3])];
        assert!(!PhysicsSceneReader::validate(&bodies));
    }

    // --- TrajectoryWriter / Reader ---

    #[test]
    fn trajectory_writer_roundtrip() {
        let mut tw = TrajectoryWriter::new();
        let mut f = TrajectoryFrame::new(0.1);
        f.positions.push([1.0, 2.0, 3.0]);
        f.velocities.push([0.1, 0.0, 0.0]);
        f.quaternions.push([1.0, 0.0, 0.0, 0.0]);
        tw.push_frame(f);
        let mut buf = Vec::new();
        tw.write_binary(&mut buf).unwrap();
        let reader = TrajectoryReader::from_bytes(&buf).expect("parse failed");
        assert_eq!(reader.num_frames(), 1);
    }

    #[test]
    fn trajectory_reader_seek_and_iterate() {
        let mut tw = TrajectoryWriter::new();
        for i in 0..3 {
            let mut f = TrajectoryFrame::new(i as f64 * 0.1);
            f.positions.push([i as f64, 0.0, 0.0]);
            f.velocities.push([0.0; 3]);
            f.quaternions.push([1.0, 0.0, 0.0, 0.0]);
            tw.push_frame(f);
        }
        let mut buf = Vec::new();
        tw.write_binary(&mut buf).unwrap();
        let mut reader = TrajectoryReader::from_bytes(&buf).unwrap();
        reader.seek(1);
        let frame = reader.next_frame().unwrap();
        assert!((frame.time - 0.1).abs() < 1e-4, "time={}", frame.time);
    }

    #[test]
    fn trajectory_writer_csv() {
        let mut tw = TrajectoryWriter::new();
        let mut f = TrajectoryFrame::new(0.0);
        f.positions.push([0.0, 0.0, 0.0]);
        f.velocities.push([0.0; 3]);
        f.quaternions.push([1.0, 0.0, 0.0, 0.0]);
        tw.push_frame(f);
        let mut buf = Vec::new();
        tw.write_csv(&mut buf).unwrap();
        let s = String::from_utf8(buf).unwrap();
        assert!(s.contains("time,body"));
    }

    // --- ContactForceLog ---

    #[test]
    fn contact_force_log_push_and_filter() {
        let mut log = ContactForceLog::new();
        log.push(ContactForceRecord {
            step: 0,
            body_a: 1,
            body_b: 2,
            normal_impulse: 0.5,
            tangent_impulse: 0.1,
            point: [0.0; 3],
            normal: [0.0, 1.0, 0.0],
        });
        log.push(ContactForceRecord {
            step: 1,
            body_a: 1,
            body_b: 3,
            normal_impulse: 0.2,
            tangent_impulse: 0.0,
            point: [0.0; 3],
            normal: [1.0, 0.0, 0.0],
        });
        assert_eq!(log.records_at_step(0).len(), 1);
        assert_eq!(log.records_at_step(1).len(), 1);
        assert_eq!(log.records_at_step(2).len(), 0);
    }

    #[test]
    fn contact_force_log_csv() {
        let mut log = ContactForceLog::new();
        log.push(ContactForceRecord {
            step: 0,
            body_a: 1,
            body_b: 2,
            normal_impulse: 1.0,
            tangent_impulse: 0.5,
            point: [0.0; 3],
            normal: [0.0, 1.0, 0.0],
        });
        let mut buf = Vec::new();
        log.write_csv(&mut buf).unwrap();
        let s = String::from_utf8(buf).unwrap();
        assert!(s.contains("step,body_a"));
    }

    // --- EnergyLog ---

    #[test]
    fn energy_log_max_energy() {
        let mut log = EnergyLog::new();
        log.push(EnergyEntry {
            time: 0.0,
            kinetic: 1.0,
            potential: 2.0,
            total: 3.0,
            linear_momentum: 0.5,
            angular_momentum: 0.2,
        });
        log.push(EnergyEntry {
            time: 0.1,
            kinetic: 2.0,
            potential: 1.0,
            total: 3.0,
            linear_momentum: 0.4,
            angular_momentum: 0.1,
        });
        assert!((log.max_total_energy() - 3.0).abs() < 1e-9);
    }

    #[test]
    fn energy_log_csv() {
        let mut log = EnergyLog::new();
        log.push(EnergyEntry {
            time: 0.0,
            kinetic: 1.0,
            potential: 1.0,
            total: 2.0,
            linear_momentum: 0.0,
            angular_momentum: 0.0,
        });
        let mut buf = Vec::new();
        log.write_csv(&mut buf).unwrap();
        let s = String::from_utf8(buf).unwrap();
        assert!(s.contains("time,kinetic"));
    }

    // --- CheckpointWriter / Reader ---

    #[test]
    fn checkpoint_roundtrip() {
        let mut cw = CheckpointWriter::new(1.5, 10);
        cw.add_body(BodyDesc::new(1, [0.0, 1.0, 0.0]));
        cw.add_metadata("solver", "PGS");
        let mut buf = Vec::new();
        cw.write_json(&mut buf).unwrap();
        let s = String::from_utf8(buf).unwrap();
        let (t, step, _bodies) = CheckpointReader::read_json(&s);
        assert!((t - 1.5).abs() < 1e-9, "t={t}");
        assert_eq!(step, 10);
    }

    #[test]
    fn checkpoint_metadata_in_output() {
        let mut cw = CheckpointWriter::new(0.0, 0);
        cw.add_metadata("version", "1.0");
        let mut buf = Vec::new();
        cw.write_json(&mut buf).unwrap();
        let s = String::from_utf8(buf).unwrap();
        assert!(s.contains("version"));
    }

    // --- ParaviewXdmf ---

    #[test]
    fn xdmf_write_produces_xml() {
        // A per-process file name keeps a stale file from an earlier run (or a
        // concurrently running test binary) from masking a failed write.
        let xdmf_path =
            std::env::temp_dir().join(format!("oxiphysics_test_traj_{}.xdmf", std::process::id()));
        let mut xdmf = ParaviewXdmf::new(xdmf_path.to_str().unwrap(), "data.h5");
        xdmf.add_timestep(0.0, 10);
        xdmf.add_timestep(0.1, 10);
        // A write failure must fail the test here rather than be swallowed.
        xdmf.write().expect("writing the XDMF file failed");
        let content = std::fs::read_to_string(&xdmf_path).expect("reading the XDMF file failed");
        // Best-effort cleanup before asserting so a failed assertion leaves no file behind.
        let _ = std::fs::remove_file(&xdmf_path);
        assert!(content.contains("Xdmf"));
        assert!(content.contains("step_0"));
    }
}