1#![allow(dead_code)]
7
8use std::collections::HashMap;
9
10use super::file::Hdf5File;
11use super::types::{AttrValue, ExternalRef, Hdf5Dtype, Hdf5Error, Hdf5Result, Hyperslab};
12
/// In-memory trajectory buffer that accumulates frames before they are flushed to a file.
///
/// Per-atom data (`positions`, `velocities`, `forces`) is stored flat and frame-major:
/// each frame contributes `n_atoms * 3` consecutive values (`x, y, z` per atom).
#[derive(Clone, Debug)]
pub struct TrajectoryStore {
    /// Number of atoms in every frame.
    pub n_atoms: usize,
    /// Number of frames appended so far.
    pub n_frames: usize,
    /// Flat positions, `n_atoms * 3` values per frame.
    pub positions: Vec<f64>,
    /// Flat velocities, `n_atoms * 3` values per appended velocity block.
    pub velocities: Vec<f64>,
    /// Flat forces, `n_atoms * 3` values per appended force block.
    pub forces: Vec<f64>,
    /// One time stamp per frame.
    pub times: Vec<f64>,
    /// Box vectors, 9 values per frame.
    pub box_vectors: Vec<f64>,
    /// Set once any velocities have been appended.
    pub has_velocities: bool,
    /// Set once any forces have been appended.
    pub has_forces: bool,
}
41
42impl TrajectoryStore {
43 pub fn new(n_atoms: usize) -> Self {
45 Self {
46 n_atoms,
47 n_frames: 0,
48 positions: Vec::new(),
49 velocities: Vec::new(),
50 forces: Vec::new(),
51 times: Vec::new(),
52 box_vectors: Vec::new(),
53 has_velocities: false,
54 has_forces: false,
55 }
56 }
57
58 pub fn append_frame(&mut self, time: f64, positions: &[f64], box_vecs: &[f64; 9]) {
60 assert_eq!(
61 positions.len(),
62 self.n_atoms * 3,
63 "append_frame: positions must have n_atoms*3 elements"
64 );
65 self.times.push(time);
66 self.positions.extend_from_slice(positions);
67 self.box_vectors.extend_from_slice(box_vecs);
68 self.n_frames += 1;
69 }
70
71 pub fn append_velocities(&mut self, velocities: &[f64]) {
75 assert_eq!(velocities.len(), self.n_atoms * 3);
76 self.velocities.extend_from_slice(velocities);
77 self.has_velocities = true;
78 }
79
80 pub fn append_forces(&mut self, forces: &[f64]) {
82 assert_eq!(forces.len(), self.n_atoms * 3);
83 self.forces.extend_from_slice(forces);
84 self.has_forces = true;
85 }
86
87 pub fn read_positions(&self, frame_id: usize) -> Hdf5Result<Vec<[f64; 3]>> {
89 if frame_id >= self.n_frames {
90 return Err(Hdf5Error::NotFound(format!("frame {frame_id}")));
91 }
92 let base = frame_id * self.n_atoms * 3;
93 let out: Vec<[f64; 3]> = (0..self.n_atoms)
94 .map(|a| {
95 let i = base + a * 3;
96 [
97 self.positions[i],
98 self.positions[i + 1],
99 self.positions[i + 2],
100 ]
101 })
102 .collect();
103 Ok(out)
104 }
105
106 pub fn read_velocities(&self, frame_id: usize) -> Hdf5Result<Vec<[f64; 3]>> {
108 if !self.has_velocities {
109 return Err(Hdf5Error::Generic(
110 "trajectory has no velocities".to_string(),
111 ));
112 }
113 if frame_id >= self.n_frames {
114 return Err(Hdf5Error::NotFound(format!("frame {frame_id}")));
115 }
116 let base = frame_id * self.n_atoms * 3;
117 let out: Vec<[f64; 3]> = (0..self.n_atoms)
118 .map(|a| {
119 let i = base + a * 3;
120 [
121 self.velocities[i],
122 self.velocities[i + 1],
123 self.velocities[i + 2],
124 ]
125 })
126 .collect();
127 Ok(out)
128 }
129
130 pub fn total_time(&self) -> f64 {
132 self.times.last().copied().unwrap_or(0.0) - self.times.first().copied().unwrap_or(0.0)
133 }
134
135 pub fn flush_to_file(&self, file: &mut Hdf5File) -> Hdf5Result<()> {
140 file.create_group("trajectory")?;
141 file.create_dataset(
142 "trajectory",
143 "positions",
144 vec![self.n_frames, self.n_atoms, 3],
145 Hdf5Dtype::Float64,
146 )?;
147 let ds = file.open_dataset_mut("trajectory", "positions")?;
148 ds.write_f64(&self.positions)?;
149
150 let group = file.open_group_mut("trajectory")?;
152 group.create_dataset("time", vec![self.n_frames], Hdf5Dtype::Float64)?;
153 let tds = group.open_dataset_mut("time")?;
154 tds.write_f64(&self.times)?;
155 Ok(())
156 }
157}
158
/// A single simulation snapshot that can be saved under `checkpoints/{label}` in a file.
///
/// `positions` and `velocities` are flat `x, y, z` triples, one per atom.
#[derive(Clone, Debug)]
pub struct Checkpoint {
    /// Name of the checkpoint; used as the group name under `checkpoints/`.
    pub label: String,
    /// Flat per-atom positions (`x, y, z` per atom).
    pub positions: Vec<f64>,
    /// Flat per-atom velocities (`x, y, z` per atom).
    pub velocities: Vec<f64>,
    /// Step counter at which the snapshot was taken.
    pub step: u64,
    /// Simulation time at which the snapshot was taken.
    pub time: f64,
    /// Box vectors as 9 values, written to file as a 3x3 dataset.
    pub box_vectors: [f64; 9],
    /// Extra named scalar values; these are not persisted by the file read/write routines.
    pub scalars: HashMap<String, f64>,
}
181
182impl Checkpoint {
183 pub fn new(label: &str, step: u64, time: f64) -> Self {
185 Self {
186 label: label.to_string(),
187 positions: Vec::new(),
188 velocities: Vec::new(),
189 step,
190 time,
191 box_vectors: [0.0; 9],
192 scalars: HashMap::new(),
193 }
194 }
195
196 pub fn write_to_file(&self, file: &mut Hdf5File) -> Hdf5Result<()> {
198 let base = format!("checkpoints/{}", self.label);
199 file.create_group("checkpoints")?;
200 file.create_group(&base)?;
201
202 let n_atoms = self.positions.len() / 3;
203
204 file.create_dataset(&base, "positions", vec![n_atoms, 3], Hdf5Dtype::Float64)?;
206 file.open_dataset_mut(&base, "positions")?
207 .write_f64(&self.positions)?;
208
209 file.create_dataset(&base, "velocities", vec![n_atoms, 3], Hdf5Dtype::Float64)?;
211 file.open_dataset_mut(&base, "velocities")?
212 .write_f64(&self.velocities)?;
213
214 file.set_dataset_attr(
216 &base,
217 "positions",
218 "step",
219 AttrValue::Int32(self.step as i32),
220 )?;
221 file.set_dataset_attr(&base, "positions", "time", AttrValue::Float64(self.time))?;
222
223 file.create_dataset(&base, "box", vec![3, 3], Hdf5Dtype::Float64)?;
225 file.open_dataset_mut(&base, "box")?
226 .write_f64(&self.box_vectors)?;
227
228 Ok(())
229 }
230
231 pub fn read_from_file(file: &Hdf5File, label: &str) -> Hdf5Result<Self> {
233 let base = format!("checkpoints/{label}");
234 let group = file.open_group(&base)?;
235
236 let pos_ds = group.open_dataset("positions")?;
237 let positions = pos_ds.read_f64()?;
238 let velocities = group.open_dataset("velocities")?.read_f64()?;
239 let box_flat = group.open_dataset("box")?.read_f64()?;
240
241 let step = match pos_ds.get_attr("step")? {
243 AttrValue::Int32(v) => *v as u64,
244 _ => 0,
245 };
246 let time = match pos_ds.get_attr("time")? {
247 AttrValue::Float64(v) => *v,
248 _ => 0.0,
249 };
250
251 let mut box_vectors = [0.0_f64; 9];
252 for (i, &v) in box_flat.iter().take(9).enumerate() {
253 box_vectors[i] = v;
254 }
255
256 Ok(Self {
257 label: label.to_string(),
258 positions,
259 velocities,
260 step,
261 time,
262 box_vectors,
263 scalars: HashMap::new(),
264 })
265 }
266}
267
268#[derive(Debug, Clone)]
274pub struct VirtualDataset {
275 pub shape: Vec<usize>,
277 pub sources: Vec<(ExternalRef, Hyperslab)>,
279 pub dtype: Hdf5Dtype,
281}
282
283impl VirtualDataset {
284 pub fn new(shape: Vec<usize>, dtype: Hdf5Dtype) -> Self {
286 Self {
287 shape,
288 sources: Vec::new(),
289 dtype,
290 }
291 }
292
293 pub fn add_source(&mut self, ext_ref: ExternalRef, slab: Hyperslab) {
295 self.sources.push((ext_ref, slab));
296 }
297
298 pub fn resolve_description(&self) -> String {
300 let src_strs: Vec<String> = self
301 .sources
302 .iter()
303 .map(|(r, s)| {
304 format!(
305 "{}:{}@offset={}->slab_vol={}",
306 r.filename,
307 r.dataset_path,
308 r.byte_offset,
309 s.volume()
310 )
311 })
312 .collect();
313 format!(
314 "VDS shape={:?} sources=[{}]",
315 self.shape,
316 src_strs.join(", ")
317 )
318 }
319}
320
/// Tracks the labels of recently written checkpoints, keeping at most `max_checkpoints`.
#[derive(Clone, Debug)]
pub struct CheckpointManager {
    /// Upper bound on the number of labels retained.
    pub max_checkpoints: usize,
    /// Retained labels, oldest first; the last entry is the most recent checkpoint.
    pub checkpoint_labels: Vec<String>,
}
333
334impl CheckpointManager {
335 pub fn new(max_checkpoints: usize) -> Self {
337 Self {
338 max_checkpoints,
339 checkpoint_labels: Vec::new(),
340 }
341 }
342
343 pub fn write(&mut self, file: &mut Hdf5File, ckpt: &Checkpoint) -> Hdf5Result<()> {
345 ckpt.write_to_file(file)?;
346 self.checkpoint_labels.push(ckpt.label.clone());
347 if self.checkpoint_labels.len() > self.max_checkpoints {
348 let _old = self.checkpoint_labels.remove(0);
349 }
351 Ok(())
352 }
353
354 pub fn read_latest(&self, file: &Hdf5File) -> Hdf5Result<Checkpoint> {
356 let label = self
357 .checkpoint_labels
358 .last()
359 .ok_or_else(|| Hdf5Error::Generic("no checkpoints stored".to_string()))?;
360 Checkpoint::read_from_file(file, label)
361 }
362
363 pub fn count(&self) -> usize {
365 self.checkpoint_labels.len()
366 }
367}