#![allow(dead_code)]
use std::collections::HashMap;
use super::file::Hdf5File;
use super::types::{AttrValue, ExternalRef, Hdf5Dtype, Hdf5Error, Hdf5Result, Hyperslab};
/// In-memory trajectory buffer: per-frame positions, time stamps and box
/// vectors, plus optional velocity/force blocks.
///
/// All per-atom arrays are flat `f64` buffers laid out frame-major as
/// `[frame][atom][x, y, z]`, i.e. `n_atoms * 3` values per frame.
#[derive(Debug, Clone, PartialEq)]
pub struct TrajectoryStore {
    /// Number of atoms in every frame.
    pub n_atoms: usize,
    /// Number of frames appended with `append_frame`.
    pub n_frames: usize,
    /// Flat positions, `n_frames * n_atoms * 3` values.
    pub positions: Vec<f64>,
    /// Flat velocities, `n_atoms * 3` values per `append_velocities` call.
    pub velocities: Vec<f64>,
    /// Flat forces, `n_atoms * 3` values per `append_forces` call.
    pub forces: Vec<f64>,
    /// One time stamp per frame.
    pub times: Vec<f64>,
    /// Nine box components per frame (presumably a row-major 3x3 cell
    /// matrix — TODO confirm).
    pub box_vectors: Vec<f64>,
    /// Set once any velocities have been appended.
    pub has_velocities: bool,
    /// Set once any forces have been appended.
    pub has_forces: bool,
}
impl TrajectoryStore {
    /// Creates an empty trajectory for a system with `n_atoms` atoms.
    pub fn new(n_atoms: usize) -> Self {
        Self {
            n_atoms,
            n_frames: 0,
            positions: Vec::new(),
            velocities: Vec::new(),
            forces: Vec::new(),
            times: Vec::new(),
            box_vectors: Vec::new(),
            has_velocities: false,
            has_forces: false,
        }
    }

    /// Appends one frame: its time stamp, `n_atoms * 3` flat xyz positions and
    /// the nine box-vector components.
    ///
    /// # Panics
    ///
    /// Panics if `positions.len() != n_atoms * 3`.
    pub fn append_frame(&mut self, time: f64, positions: &[f64], box_vecs: &[f64; 9]) {
        assert_eq!(
            positions.len(),
            self.n_atoms * 3,
            "append_frame: positions must have n_atoms*3 elements"
        );
        self.times.push(time);
        self.positions.extend_from_slice(positions);
        self.box_vectors.extend_from_slice(box_vecs);
        self.n_frames += 1;
    }

    /// Appends one block of `n_atoms * 3` velocity values.
    ///
    /// Nothing links this call to `append_frame`. `read_velocities` treats the
    /// i-th appended block as the velocities of frame i.
    ///
    /// # Panics
    ///
    /// Panics if `velocities.len() != n_atoms * 3`.
    pub fn append_velocities(&mut self, velocities: &[f64]) {
        assert_eq!(velocities.len(), self.n_atoms * 3);
        self.velocities.extend_from_slice(velocities);
        self.has_velocities = true;
    }

    /// Appends one block of `n_atoms * 3` force values.
    ///
    /// # Panics
    ///
    /// Panics if `forces.len() != n_atoms * 3`.
    pub fn append_forces(&mut self, forces: &[f64]) {
        assert_eq!(forces.len(), self.n_atoms * 3);
        self.forces.extend_from_slice(forces);
        self.has_forces = true;
    }

    /// Returns the `[x, y, z]` position of every atom in frame `frame_id`.
    ///
    /// # Errors
    ///
    /// Returns `Hdf5Error::NotFound` if `frame_id >= n_frames`, or if the
    /// (public) `positions` buffer is too short to hold that frame.
    pub fn read_positions(&self, frame_id: usize) -> Hdf5Result<Vec<[f64; 3]>> {
        if frame_id >= self.n_frames {
            return Err(Hdf5Error::NotFound(format!("frame {frame_id}")));
        }
        Self::frame_triples(&self.positions, self.n_atoms, frame_id)
            .ok_or_else(|| Hdf5Error::NotFound(format!("frame {frame_id}")))
    }

    /// Returns the `[vx, vy, vz]` velocity of every atom in frame `frame_id`.
    ///
    /// # Errors
    ///
    /// - `Hdf5Error::Generic` if no velocities were ever appended.
    /// - `Hdf5Error::NotFound` if `frame_id >= n_frames`.
    /// - `Hdf5Error::NotFound` if fewer than `frame_id + 1` velocity blocks
    ///   were appended. Without this check the read would index past the end
    ///   of the buffer and panic.
    pub fn read_velocities(&self, frame_id: usize) -> Hdf5Result<Vec<[f64; 3]>> {
        if !self.has_velocities {
            return Err(Hdf5Error::Generic(
                "trajectory has no velocities".to_string(),
            ));
        }
        if frame_id >= self.n_frames {
            return Err(Hdf5Error::NotFound(format!("frame {frame_id}")));
        }
        Self::frame_triples(&self.velocities, self.n_atoms, frame_id).ok_or_else(|| {
            Hdf5Error::NotFound(format!("velocities for frame {frame_id}"))
        })
    }

    /// Copies frame `frame_id` out of a flat `[frame][atom][xyz]` buffer.
    /// Returns `None` when the buffer does not contain that whole frame.
    fn frame_triples(data: &[f64], n_atoms: usize, frame_id: usize) -> Option<Vec<[f64; 3]>> {
        let stride = n_atoms.checked_mul(3)?;
        let start = frame_id.checked_mul(stride)?;
        let end = start.checked_add(stride)?;
        let frame = data.get(start..end)?;
        Some(frame.chunks_exact(3).map(|c| [c[0], c[1], c[2]]).collect())
    }

    /// Last time stamp minus the first one appended; 0.0 when there are no frames.
    pub fn total_time(&self) -> f64 {
        self.times.last().copied().unwrap_or(0.0) - self.times.first().copied().unwrap_or(0.0)
    }

    /// Writes the group `trajectory` with dataset `positions`
    /// (`[n_frames, n_atoms, 3]`, Float64) and dataset `time` (`[n_frames]`).
    ///
    /// NOTE(review): velocities, forces and box vectors are not written here.
    ///
    /// # Errors
    ///
    /// Propagates any error from the underlying file operations.
    pub fn flush_to_file(&self, file: &mut Hdf5File) -> Hdf5Result<()> {
        file.create_group("trajectory")?;
        file.create_dataset(
            "trajectory",
            "positions",
            vec![self.n_frames, self.n_atoms, 3],
            Hdf5Dtype::Float64,
        )?;
        let ds = file.open_dataset_mut("trajectory", "positions")?;
        ds.write_f64(&self.positions)?;
        let group = file.open_group_mut("trajectory")?;
        group.create_dataset("time", vec![self.n_frames], Hdf5Dtype::Float64)?;
        let tds = group.open_dataset_mut("time")?;
        tds.write_f64(&self.times)?;
        Ok(())
    }
}
/// A single-frame snapshot of simulation state, stored in a file under
/// `checkpoints/<label>`.
#[derive(Debug, Clone, PartialEq)]
pub struct Checkpoint {
    /// Name of the checkpoint; also its group name inside `checkpoints/`.
    pub label: String,
    /// Flat `[atom][x, y, z]` positions; the length is expected to be a
    /// multiple of 3.
    pub positions: Vec<f64>,
    /// Flat `[atom][x, y, z]` velocities, same layout as `positions`.
    pub velocities: Vec<f64>,
    /// Step counter; persisted as an `Int32` attribute on `positions`.
    pub step: u64,
    /// Time value; persisted as a `Float64` attribute on `positions`.
    pub time: f64,
    /// Nine box components, written as a 3x3 dataset named `box`.
    pub box_vectors: [f64; 9],
    /// Extra named values.
    /// NOTE(review): neither `write_to_file` nor `read_from_file` persists these.
    pub scalars: HashMap<String, f64>,
}
impl Checkpoint {
pub fn new(label: &str, step: u64, time: f64) -> Self {
Self {
label: label.to_string(),
positions: Vec::new(),
velocities: Vec::new(),
step,
time,
box_vectors: [0.0; 9],
scalars: HashMap::new(),
}
}
pub fn write_to_file(&self, file: &mut Hdf5File) -> Hdf5Result<()> {
let base = format!("checkpoints/{}", self.label);
file.create_group("checkpoints")?;
file.create_group(&base)?;
let n_atoms = self.positions.len() / 3;
file.create_dataset(&base, "positions", vec![n_atoms, 3], Hdf5Dtype::Float64)?;
file.open_dataset_mut(&base, "positions")?
.write_f64(&self.positions)?;
file.create_dataset(&base, "velocities", vec![n_atoms, 3], Hdf5Dtype::Float64)?;
file.open_dataset_mut(&base, "velocities")?
.write_f64(&self.velocities)?;
file.set_dataset_attr(
&base,
"positions",
"step",
AttrValue::Int32(self.step as i32),
)?;
file.set_dataset_attr(&base, "positions", "time", AttrValue::Float64(self.time))?;
file.create_dataset(&base, "box", vec![3, 3], Hdf5Dtype::Float64)?;
file.open_dataset_mut(&base, "box")?
.write_f64(&self.box_vectors)?;
Ok(())
}
pub fn read_from_file(file: &Hdf5File, label: &str) -> Hdf5Result<Self> {
let base = format!("checkpoints/{label}");
let group = file.open_group(&base)?;
let pos_ds = group.open_dataset("positions")?;
let positions = pos_ds.read_f64()?;
let velocities = group.open_dataset("velocities")?.read_f64()?;
let box_flat = group.open_dataset("box")?.read_f64()?;
let step = match pos_ds.get_attr("step")? {
AttrValue::Int32(v) => *v as u64,
_ => 0,
};
let time = match pos_ds.get_attr("time")? {
AttrValue::Float64(v) => *v,
_ => 0.0,
};
let mut box_vectors = [0.0_f64; 9];
for (i, &v) in box_flat.iter().take(9).enumerate() {
box_vectors[i] = v;
}
Ok(Self {
label: label.to_string(),
positions,
velocities,
step,
time,
box_vectors,
scalars: HashMap::new(),
})
}
}
/// Description of a virtual dataset assembled from hyperslabs of external
/// datasets.
///
/// The methods defined for this type only record and describe sources; they
/// do not read any data.
#[derive(Clone, Debug)]
pub struct VirtualDataset {
    /// Logical shape of the assembled dataset.
    pub shape: Vec<usize>,
    /// Each source pairs an external dataset reference with the hyperslab
    /// selected from it.
    pub sources: Vec<(ExternalRef, Hyperslab)>,
    /// Element type of the assembled dataset.
    pub dtype: Hdf5Dtype,
}
impl VirtualDataset {
    /// Builds a virtual dataset of the given shape and element type, with no
    /// sources yet.
    pub fn new(shape: Vec<usize>, dtype: Hdf5Dtype) -> Self {
        let sources = Vec::new();
        Self {
            shape,
            sources,
            dtype,
        }
    }

    /// Records one more source: an external reference and the hyperslab taken
    /// from it.
    pub fn add_source(&mut self, ext_ref: ExternalRef, slab: Hyperslab) {
        let entry = (ext_ref, slab);
        self.sources.push(entry);
    }

    /// Renders a one-line, human-readable summary of the shape and every
    /// source, in insertion order.
    pub fn resolve_description(&self) -> String {
        let mut parts: Vec<String> = Vec::with_capacity(self.sources.len());
        for (ext, slab) in &self.sources {
            let part = format!(
                "{}:{}@offset={}->slab_vol={}",
                ext.filename,
                ext.dataset_path,
                ext.byte_offset,
                slab.volume()
            );
            parts.push(part);
        }
        let joined = parts.join(", ");
        format!("VDS shape={:?} sources=[{}]", self.shape, joined)
    }
}
/// Keeps track of the labels of checkpoints written to a file, retaining at
/// most `max_checkpoints` of them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CheckpointManager {
    /// Maximum number of labels kept in `checkpoint_labels`.
    pub max_checkpoints: usize,
    /// Retained labels, oldest first and most recent last.
    pub checkpoint_labels: Vec<String>,
}
impl CheckpointManager {
    /// Creates a manager that retains at most `max_checkpoints` labels.
    pub fn new(max_checkpoints: usize) -> Self {
        Self {
            max_checkpoints,
            checkpoint_labels: Vec::new(),
        }
    }

    /// Writes `ckpt` to `file` and records its label as the most recent one.
    ///
    /// If `ckpt.label` was already tracked, it is moved to the newest slot
    /// rather than duplicated. Afterwards the oldest labels are dropped until
    /// at most `max_checkpoints` remain. Dropping a label only forgets it: the
    /// checkpoint data stays in `file`, because nothing here deletes it.
    ///
    /// # Errors
    ///
    /// Propagates any error from `Checkpoint::write_to_file`. In that case the
    /// label list is left unchanged.
    pub fn write(&mut self, file: &mut Hdf5File, ckpt: &Checkpoint) -> Hdf5Result<()> {
        ckpt.write_to_file(file)?;
        // Keep each label once so that `count()` and eviction reflect distinct
        // checkpoints.
        self.checkpoint_labels.retain(|l| l != &ckpt.label);
        self.checkpoint_labels.push(ckpt.label.clone());
        // Trim all the way to the limit. The fields are public, so the list
        // may already have been longer than `max_checkpoints`.
        let excess = self
            .checkpoint_labels
            .len()
            .saturating_sub(self.max_checkpoints);
        self.checkpoint_labels.drain(..excess);
        Ok(())
    }

    /// Reads back the most recently recorded checkpoint.
    ///
    /// # Errors
    ///
    /// Returns `Hdf5Error::Generic` if no label is recorded. Otherwise
    /// propagates any error from `Checkpoint::read_from_file`.
    pub fn read_latest(&self, file: &Hdf5File) -> Hdf5Result<Checkpoint> {
        let label = self
            .checkpoint_labels
            .last()
            .ok_or_else(|| Hdf5Error::Generic("no checkpoints stored".to_string()))?;
        Checkpoint::read_from_file(file, label)
    }

    /// Number of labels currently retained.
    pub fn count(&self) -> usize {
        self.checkpoint_labels.len()
    }
}