Skip to main content

oxiphysics_io/hdf5_io/
trajectory.rs

1// Copyright 2026 COOLJAPAN OU (Team KitaSan)
2// SPDX-License-Identifier: Apache-2.0
3
4//! Trajectory storage, checkpoints, virtual datasets.
5
6#![allow(dead_code)]
7
8use std::collections::HashMap;
9
10use super::file::Hdf5File;
11use super::types::{AttrValue, ExternalRef, Hdf5Dtype, Hdf5Error, Hdf5Result, Hyperslab};
12
13// ---------------------------------------------------------------------------
14// Trajectory storage
15// ---------------------------------------------------------------------------
16
/// In-memory container for a multi-frame particle trajectory.
///
/// Every per-atom buffer is flat and frame-major, i.e. logically shaped
/// `[n_frames, n_atoms, 3]`. The layout is intended to follow the H5MD
/// convention (the format read by MDAnalysis / h5md tooling).
#[derive(Debug, Clone)]
pub struct TrajectoryStore {
    /// Atom count shared by every frame.
    pub n_atoms: usize,
    /// How many frames have been appended so far.
    pub n_frames: usize,
    /// Coordinates; element index is `frame * n_atoms * 3 + atom * 3 + xyz`.
    pub positions: Vec<f64>,
    /// Velocities in the same layout as `positions` (optional, may be empty).
    pub velocities: Vec<f64>,
    /// Forces in the same layout as `positions` (optional, may be empty).
    pub forces: Vec<f64>,
    /// One time stamp per frame, in picoseconds.
    pub times: Vec<f64>,
    /// Box vectors per frame: nine values (row-major 3x3) starting at `frame * 9`.
    pub box_vectors: Vec<f64>,
    /// `true` when velocity data is stored.
    pub has_velocities: bool,
    /// `true` when force data is stored.
    pub has_forces: bool,
}
41
42impl TrajectoryStore {
43    /// Create an empty trajectory for `n_atoms` atoms.
44    pub fn new(n_atoms: usize) -> Self {
45        Self {
46            n_atoms,
47            n_frames: 0,
48            positions: Vec::new(),
49            velocities: Vec::new(),
50            forces: Vec::new(),
51            times: Vec::new(),
52            box_vectors: Vec::new(),
53            has_velocities: false,
54            has_forces: false,
55        }
56    }
57
58    /// Append a single frame from a slice of positions `[n_atoms * 3]`.
59    pub fn append_frame(&mut self, time: f64, positions: &[f64], box_vecs: &[f64; 9]) {
60        assert_eq!(
61            positions.len(),
62            self.n_atoms * 3,
63            "append_frame: positions must have n_atoms*3 elements"
64        );
65        self.times.push(time);
66        self.positions.extend_from_slice(positions);
67        self.box_vectors.extend_from_slice(box_vecs);
68        self.n_frames += 1;
69    }
70
71    /// Append velocities for the last frame.
72    ///
73    /// Must be called immediately after `append_frame` and before the next one.
74    pub fn append_velocities(&mut self, velocities: &[f64]) {
75        assert_eq!(velocities.len(), self.n_atoms * 3);
76        self.velocities.extend_from_slice(velocities);
77        self.has_velocities = true;
78    }
79
80    /// Append forces for the last frame.
81    pub fn append_forces(&mut self, forces: &[f64]) {
82        assert_eq!(forces.len(), self.n_atoms * 3);
83        self.forces.extend_from_slice(forces);
84        self.has_forces = true;
85    }
86
87    /// Read back the positions for frame `frame_id`.
88    pub fn read_positions(&self, frame_id: usize) -> Hdf5Result<Vec<[f64; 3]>> {
89        if frame_id >= self.n_frames {
90            return Err(Hdf5Error::NotFound(format!("frame {frame_id}")));
91        }
92        let base = frame_id * self.n_atoms * 3;
93        let out: Vec<[f64; 3]> = (0..self.n_atoms)
94            .map(|a| {
95                let i = base + a * 3;
96                [
97                    self.positions[i],
98                    self.positions[i + 1],
99                    self.positions[i + 2],
100                ]
101            })
102            .collect();
103        Ok(out)
104    }
105
106    /// Read back the velocities for frame `frame_id`.
107    pub fn read_velocities(&self, frame_id: usize) -> Hdf5Result<Vec<[f64; 3]>> {
108        if !self.has_velocities {
109            return Err(Hdf5Error::Generic(
110                "trajectory has no velocities".to_string(),
111            ));
112        }
113        if frame_id >= self.n_frames {
114            return Err(Hdf5Error::NotFound(format!("frame {frame_id}")));
115        }
116        let base = frame_id * self.n_atoms * 3;
117        let out: Vec<[f64; 3]> = (0..self.n_atoms)
118            .map(|a| {
119                let i = base + a * 3;
120                [
121                    self.velocities[i],
122                    self.velocities[i + 1],
123                    self.velocities[i + 2],
124                ]
125            })
126            .collect();
127        Ok(out)
128    }
129
130    /// Return the total trajectory duration in picoseconds.
131    pub fn total_time(&self) -> f64 {
132        self.times.last().copied().unwrap_or(0.0) - self.times.first().copied().unwrap_or(0.0)
133    }
134
135    /// Flush the trajectory into an HDF5 file under `/trajectory`.
136    ///
137    /// Creates two datasets: `positions` of shape `[n_frames, n_atoms, 3]`
138    /// and `time` of shape `[n_frames]`.
139    pub fn flush_to_file(&self, file: &mut Hdf5File) -> Hdf5Result<()> {
140        file.create_group("trajectory")?;
141        file.create_dataset(
142            "trajectory",
143            "positions",
144            vec![self.n_frames, self.n_atoms, 3],
145            Hdf5Dtype::Float64,
146        )?;
147        let ds = file.open_dataset_mut("trajectory", "positions")?;
148        ds.write_f64(&self.positions)?;
149
150        // time dataset
151        let group = file.open_group_mut("trajectory")?;
152        group.create_dataset("time", vec![self.n_frames], Hdf5Dtype::Float64)?;
153        let tds = group.open_dataset_mut("time")?;
154        tds.write_f64(&self.times)?;
155        Ok(())
156    }
157}
158
159// ---------------------------------------------------------------------------
160// Checkpoint / restart
161// ---------------------------------------------------------------------------
162
/// Snapshot of the simulation state used for checkpoint / restart.
#[derive(Debug, Clone)]
pub struct Checkpoint {
    /// Name identifying this checkpoint (e.g. `"step_1000"`).
    pub label: String,
    /// Flat atom positions, `n_atoms * 3` values.
    pub positions: Vec<f64>,
    /// Flat atom velocities, `n_atoms * 3` values.
    pub velocities: Vec<f64>,
    /// Step counter of the simulation at checkpoint time.
    pub step: u64,
    /// Simulation time at checkpoint time (ps).
    pub time: f64,
    /// Box vectors: nine values forming a row-major 3x3 matrix.
    pub box_vectors: [f64; 9],
    /// Optional extra named scalar values.
    pub scalars: HashMap<String, f64>,
}
181
182impl Checkpoint {
183    /// Create a new checkpoint.
184    pub fn new(label: &str, step: u64, time: f64) -> Self {
185        Self {
186            label: label.to_string(),
187            positions: Vec::new(),
188            velocities: Vec::new(),
189            step,
190            time,
191            box_vectors: [0.0; 9],
192            scalars: HashMap::new(),
193        }
194    }
195
196    /// Write this checkpoint into an HDF5 file under `/checkpoints/`label`.
197    pub fn write_to_file(&self, file: &mut Hdf5File) -> Hdf5Result<()> {
198        let base = format!("checkpoints/{}", self.label);
199        file.create_group("checkpoints")?;
200        file.create_group(&base)?;
201
202        let n_atoms = self.positions.len() / 3;
203
204        // positions
205        file.create_dataset(&base, "positions", vec![n_atoms, 3], Hdf5Dtype::Float64)?;
206        file.open_dataset_mut(&base, "positions")?
207            .write_f64(&self.positions)?;
208
209        // velocities
210        file.create_dataset(&base, "velocities", vec![n_atoms, 3], Hdf5Dtype::Float64)?;
211        file.open_dataset_mut(&base, "velocities")?
212            .write_f64(&self.velocities)?;
213
214        // step / time attributes
215        file.set_dataset_attr(
216            &base,
217            "positions",
218            "step",
219            AttrValue::Int32(self.step as i32),
220        )?;
221        file.set_dataset_attr(&base, "positions", "time", AttrValue::Float64(self.time))?;
222
223        // box vectors
224        file.create_dataset(&base, "box", vec![3, 3], Hdf5Dtype::Float64)?;
225        file.open_dataset_mut(&base, "box")?
226            .write_f64(&self.box_vectors)?;
227
228        Ok(())
229    }
230
231    /// Restore a checkpoint from an HDF5 file.
232    pub fn read_from_file(file: &Hdf5File, label: &str) -> Hdf5Result<Self> {
233        let base = format!("checkpoints/{label}");
234        let group = file.open_group(&base)?;
235
236        let pos_ds = group.open_dataset("positions")?;
237        let positions = pos_ds.read_f64()?;
238        let velocities = group.open_dataset("velocities")?.read_f64()?;
239        let box_flat = group.open_dataset("box")?.read_f64()?;
240
241        // Read step attribute
242        let step = match pos_ds.get_attr("step")? {
243            AttrValue::Int32(v) => *v as u64,
244            _ => 0,
245        };
246        let time = match pos_ds.get_attr("time")? {
247            AttrValue::Float64(v) => *v,
248            _ => 0.0,
249        };
250
251        let mut box_vectors = [0.0_f64; 9];
252        for (i, &v) in box_flat.iter().take(9).enumerate() {
253            box_vectors[i] = v;
254        }
255
256        Ok(Self {
257            label: label.to_string(),
258            positions,
259            velocities,
260            step,
261            time,
262            box_vectors,
263            scalars: HashMap::new(),
264        })
265    }
266}
267
268// ---------------------------------------------------------------------------
269// Virtual dataset (multi-file)
270// ---------------------------------------------------------------------------
271
272/// A virtual dataset that concatenates slices from multiple source files.
273#[derive(Debug, Clone)]
274pub struct VirtualDataset {
275    /// Final shape of the combined virtual dataset.
276    pub shape: Vec<usize>,
277    /// Sources in order: `(external_ref, hyperslab_into_vds)`.
278    pub sources: Vec<(ExternalRef, Hyperslab)>,
279    /// Datatype of the virtual dataset.
280    pub dtype: Hdf5Dtype,
281}
282
283impl VirtualDataset {
284    /// Create a new virtual dataset definition.
285    pub fn new(shape: Vec<usize>, dtype: Hdf5Dtype) -> Self {
286        Self {
287            shape,
288            sources: Vec::new(),
289            dtype,
290        }
291    }
292
293    /// Add a source contributing to a slice of the virtual dataset.
294    pub fn add_source(&mut self, ext_ref: ExternalRef, slab: Hyperslab) {
295        self.sources.push((ext_ref, slab));
296    }
297
298    /// Simulate resolving the virtual dataset (returns a description string).
299    pub fn resolve_description(&self) -> String {
300        let src_strs: Vec<String> = self
301            .sources
302            .iter()
303            .map(|(r, s)| {
304                format!(
305                    "{}:{}@offset={}->slab_vol={}",
306                    r.filename,
307                    r.dataset_path,
308                    r.byte_offset,
309                    s.volume()
310                )
311            })
312            .collect();
313        format!(
314            "VDS shape={:?} sources=[{}]",
315            self.shape,
316            src_strs.join(", ")
317        )
318    }
319}
320
321// ---------------------------------------------------------------------------
322// Checkpoint manager
323// ---------------------------------------------------------------------------
324
/// Keeps track of a rolling window of checkpoints stored in one HDF5 file.
#[derive(Debug, Clone)]
pub struct CheckpointManager {
    /// Upper bound on how many checkpoints are kept.
    pub max_checkpoints: usize,
    /// Labels of the checkpoints currently tracked, oldest first.
    pub checkpoint_labels: Vec<String>,
}
333
334impl CheckpointManager {
335    /// Create a new manager with the given capacity.
336    pub fn new(max_checkpoints: usize) -> Self {
337        Self {
338            max_checkpoints,
339            checkpoint_labels: Vec::new(),
340        }
341    }
342
343    /// Write a checkpoint, evicting the oldest if capacity is exceeded.
344    pub fn write(&mut self, file: &mut Hdf5File, ckpt: &Checkpoint) -> Hdf5Result<()> {
345        ckpt.write_to_file(file)?;
346        self.checkpoint_labels.push(ckpt.label.clone());
347        if self.checkpoint_labels.len() > self.max_checkpoints {
348            let _old = self.checkpoint_labels.remove(0);
349            // In a real implementation we would delete the group from the file.
350        }
351        Ok(())
352    }
353
354    /// Read the most recent checkpoint.
355    pub fn read_latest(&self, file: &Hdf5File) -> Hdf5Result<Checkpoint> {
356        let label = self
357            .checkpoint_labels
358            .last()
359            .ok_or_else(|| Hdf5Error::Generic("no checkpoints stored".to_string()))?;
360        Checkpoint::read_from_file(file, label)
361    }
362
363    /// Number of stored checkpoints.
364    pub fn count(&self) -> usize {
365        self.checkpoint_labels.len()
366    }
367}