use crate::prelude::*;
use std::{
fs::File,
io::{self, BufReader, BufWriter, Read, Seek, SeekFrom, Write},
path::Path,
};
use thiserror::Error;
use super::{FileFormatError, FileFormatHandler, SaveState};
/// Errors produced by the TRR reading, writing and seeking code in this file.
#[derive(Error, Debug)]
pub enum TrrHandlerError {
    /// The first integer of a frame header did not equal the expected magic number;
    /// carries the value that was actually read.
    #[error("invalid TRR magic number {0}")]
    MagicNumber(i32),
    /// The frame header was malformed (e.g. the version string did not match).
    #[error("invalid TRR header")]
    Header,
    /// An I/O failure while seeking to the given frame index.
    #[error("TRR seek to frame {0} failed")]
    SeekFrame(usize, #[source] io::Error),
    /// An I/O failure while seeking to the given time.
    #[error("TRR seek to time {0} failed")]
    SeekTime(f32, #[source] io::Error),
    /// The box matrix stored in a frame was rejected when building the periodic box.
    #[error("invalid periodic box")]
    Pbc(#[from] PeriodicBoxError),
    /// Any other I/O failure.
    #[error("unexpected io error")]
    Io(#[from] io::Error),
}
/// Reads one big-endian 32-bit signed integer from `r`.
///
/// # Errors
/// Propagates the error of `read_exact` when fewer than four bytes are available.
fn xdr_read_i32(r: &mut impl Read) -> io::Result<i32> {
    let mut raw = [0u8; 4];
    r.read_exact(&mut raw).map(|()| i32::from_be_bytes(raw))
}
/// Reads one big-endian IEEE-754 single-precision value from `r`.
///
/// # Errors
/// Propagates the error of `read_exact` when fewer than four bytes are available.
fn xdr_read_f32(r: &mut impl Read) -> io::Result<f32> {
    let mut raw = [0u8; 4];
    r.read_exact(&mut raw)?;
    Ok(f32::from_be_bytes(raw))
}
/// Reads one big-endian IEEE-754 double-precision value from `r`.
///
/// # Errors
/// Propagates the error of `read_exact` when fewer than eight bytes are available.
fn xdr_read_f64(r: &mut impl Read) -> io::Result<f64> {
    let mut raw = [0u8; 8];
    r.read_exact(&mut raw)?;
    Ok(f64::from_be_bytes(raw))
}
/// Writes `v` to `w` as four big-endian bytes.
fn xdr_write_i32(w: &mut impl Write, v: i32) -> io::Result<()> {
    let encoded: [u8; 4] = v.to_be_bytes();
    w.write_all(&encoded)
}
/// Writes `v` to `w` as a four-byte big-endian IEEE-754 single-precision value.
fn xdr_write_f32(w: &mut impl Write, v: f32) -> io::Result<()> {
    let encoded: [u8; 4] = v.to_be_bytes();
    w.write_all(&encoded)
}
/// Consumes and discards exactly `n_bytes` bytes from `r`.
///
/// The bytes are streamed into a sink instead of being read into a scratch
/// buffer of `n_bytes` bytes: the skip length is derived from values stored in
/// the file, so a corrupt header must not be able to trigger an arbitrarily
/// large allocation.
///
/// # Errors
/// Returns an error of kind [`io::ErrorKind::UnexpectedEof`] when the reader
/// ends before `n_bytes` bytes were consumed, and propagates any other read
/// error unchanged.
fn xdr_skip(r: &mut impl Read, n_bytes: usize) -> io::Result<()> {
    let wanted = n_bytes as u64;
    let skipped = io::copy(&mut r.by_ref().take(wanted), &mut io::sink())?;
    if skipped == wanted {
        Ok(())
    } else {
        // Same error kind and message that `read_exact` reports for a short read.
        Err(io::Error::new(
            io::ErrorKind::UnexpectedEof,
            "failed to fill whole buffer",
        ))
    }
}
/// Reads `n` consecutive (x, y, z) triplets from `r`.
///
/// When `b_double` is set each component is read as an 8-byte value and
/// narrowed to `f32`; otherwise each component is read as a 4-byte value.
///
/// # Errors
/// Stops at, and returns, the first read error.
fn read_xvf(r: &mut impl Read, n: usize, b_double: bool) -> io::Result<Vec<Vel>> {
    let mut triplets = Vec::new();
    for _ in 0..n {
        let item = if b_double {
            let x = xdr_read_f64(r)? as f32;
            let y = xdr_read_f64(r)? as f32;
            let z = xdr_read_f64(r)? as f32;
            Vel::new(x, y, z)
        } else {
            let x = xdr_read_f32(r)?;
            let y = xdr_read_f32(r)?;
            let z = xdr_read_f32(r)?;
            Vel::new(x, y, z)
        };
        triplets.push(item);
    }
    Ok(triplets)
}
/// Value that must appear as the first integer of every TRR frame header.
const GROMACS_MAGIC: i32 = 1993;
/// Version string that every TRR frame header must carry.
const TRR_VERSION: &str = "GMX_trn_file";
/// The subset of a TRR frame header that this module uses.
///
/// A `*_size` field is zero when the corresponding block is absent from the
/// frame.
struct TrrHeader {
    /// Byte size of the box matrix block (9 values).
    box_size: i32,
    /// Byte size of the virial block; it is skipped when reading.
    vir_size: i32,
    /// Byte size of the pressure block; it is skipped when reading.
    pres_size: i32,
    /// Byte size of the coordinates block.
    x_size: i32,
    /// Byte size of the velocities block.
    v_size: i32,
    /// Byte size of the forces block.
    f_size: i32,
    /// Number of atoms in the frame.
    natoms: i32,
    /// Step number stored in the header.
    step: i32,
    /// Frame time, narrowed to `f32` when the file stores it as 8 bytes.
    time: f32,
    /// `true` when the frame stores its real values as 8 bytes instead of 4.
    b_double: bool,
}
/// Reads and validates one TRR frame header from `r`.
///
/// The layout consumed is: magic number, one ignored integer, the version
/// string (length-prefixed, padded to a multiple of four bytes), thirteen
/// integers (block sizes, atom count, step, and one ignored value), the frame
/// time, and one trailing real value that is skipped. Whether the reals are
/// 4 or 8 bytes wide is deduced from the block sizes.
///
/// # Errors
/// * [`TrrHandlerError::MagicNumber`] if the first integer is not [`GROMACS_MAGIC`].
/// * [`TrrHandlerError::Header`] if the version string does not equal
///   [`TRR_VERSION`], or if the atom count or any block size used by this
///   module is negative.
/// * [`TrrHandlerError::Io`] on any read failure, including end of file.
fn read_trr_header(r: &mut impl Read) -> Result<TrrHeader, TrrHandlerError> {
    let magic = xdr_read_i32(r)?;
    if magic != GROMACS_MAGIC {
        return Err(TrrHandlerError::MagicNumber(magic));
    }
    let _slen = xdr_read_i32(r)?;

    // The string length comes straight from the file. Only a length equal to
    // the expected version string can ever match, so reject anything else
    // (negative values included) before using it for arithmetic or allocation.
    let str_len = xdr_read_i32(r)?;
    if str_len != TRR_VERSION.len() as i32 {
        return Err(TrrHandlerError::Header);
    }
    let str_len = TRR_VERSION.len();
    let padded = (str_len + 3) & !3;
    let mut str_buf = vec![0u8; padded];
    r.read_exact(&mut str_buf)?;
    if &str_buf[..str_len] != TRR_VERSION.as_bytes() {
        return Err(TrrHandlerError::Header);
    }

    let _ir_size = xdr_read_i32(r)?;
    let _e_size = xdr_read_i32(r)?;
    let box_size = xdr_read_i32(r)?;
    let vir_size = xdr_read_i32(r)?;
    let pres_size = xdr_read_i32(r)?;
    let _top_size = xdr_read_i32(r)?;
    let _sym_size = xdr_read_i32(r)?;
    let x_size = xdr_read_i32(r)?;
    let v_size = xdr_read_i32(r)?;
    let f_size = xdr_read_i32(r)?;
    let natoms = xdr_read_i32(r)?;
    let step = xdr_read_i32(r)?;
    let _nre = xdr_read_i32(r)?;

    // Later code casts these values to `usize`; a negative value would turn
    // into an enormous count, so treat it as a corrupt header.
    let must_be_non_negative = [box_size, vir_size, pres_size, x_size, v_size, f_size, natoms];
    if must_be_non_negative.iter().any(|&v| v < 0) {
        return Err(TrrHandlerError::Header);
    }

    // Size of a per-atom array in double precision, computed in i64 so that a
    // large atom count cannot overflow.
    let double_xyz_bytes = i64::from(natoms) * 3 * 8;
    // An absent block (size 0) says nothing about precision; without this
    // guard a zero-atom frame would always be classified as double.
    let is_double_array = |size: i32| size != 0 && i64::from(size) == double_xyz_bytes;
    let b_double = box_size == 72
        || is_double_array(x_size)
        || is_double_array(v_size)
        || is_double_array(f_size);

    let elem_size = if b_double { 8 } else { 4 };
    let time = if b_double {
        xdr_read_f64(r)? as f32
    } else {
        xdr_read_f32(r)?
    };
    // One more real value follows the time; it is not used here.
    xdr_skip(r, elem_size)?;

    Ok(TrrHeader {
        box_size,
        vir_size,
        pres_size,
        x_size,
        v_size,
        f_size,
        natoms,
        step,
        time,
        b_double,
    })
}
/// Writes a single-precision TRR frame header described by `h` to `w`.
///
/// Only `box_size`, `x_size`, `v_size`, `f_size`, `natoms`, `step` and `time`
/// are taken from `h`; every other header slot is written as zero.
///
/// # Errors
/// Returns [`TrrHandlerError::Io`] on the first failed write.
fn write_trr_header(w: &mut impl Write, h: &TrrHeader) -> Result<(), TrrHandlerError> {
    // Magic number, the fixed integer that precedes the version string, and
    // the length of the version string itself.
    for v in [GROMACS_MAGIC, 13, TRR_VERSION.len() as i32] {
        xdr_write_i32(w, v)?;
    }
    w.write_all(TRR_VERSION.as_bytes())?;

    // The thirteen integer slots, in the order the reader consumes them:
    // ir, e, box, vir, pres, top, sym, x, v, f, natoms, step, nre.
    let int_slots = [
        0, 0, h.box_size, 0, 0, 0, 0, h.x_size, h.v_size, h.f_size, h.natoms, h.step, 0,
    ];
    for v in int_slots {
        xdr_write_i32(w, v)?;
    }

    // Frame time followed by the trailing real that the reader skips.
    for v in [h.time, 0.0] {
        xdr_write_f32(w, v)?;
    }
    Ok(())
}
/// Computes the number of payload bytes that follow the header `h`.
///
/// Every matrix block that is present (box, virial, pressure) contributes
/// nine reals and every per-atom block that is present (coordinates,
/// velocities, forces) contributes three reals per atom; a real is 8 bytes
/// when `h.b_double` is set and 4 bytes otherwise.
fn frame_data_size(h: &TrrHeader) -> usize {
    let real_bytes: usize = if h.b_double { 8 } else { 4 };
    let per_atom_reals = h.natoms as usize * 3;
    // (size field from the header, number of reals in that block)
    let blocks = [
        (h.box_size, 9),
        (h.vir_size, 9),
        (h.pres_size, 9),
        (h.x_size, per_atom_reals),
        (h.v_size, per_atom_reals),
        (h.f_size, per_atom_reals),
    ];
    blocks
        .iter()
        .filter(|(declared, _)| *declared != 0)
        .map(|(_, reals)| reals * real_bytes)
        .sum()
}
/// Read-side state of an open TRR file.
pub(crate) struct TrrReader {
    /// Buffered handle to the underlying file.
    file: BufReader<File>,
    /// Index of the frame the stream is currently positioned at; advanced by
    /// one after every frame that is read or skipped.
    cur_frame: usize,
}
/// Write-side state of a TRR file being created.
pub(crate) struct TrrWriter {
    /// Buffered handle to the underlying file.
    file: BufWriter<File>,
    /// Number of frames written so far; also written as the step of the next frame.
    cur_frame: usize,
    /// Atom count of the first state written, `None` until then.
    natoms: Option<usize>,
}
/// Handler for a TRR file, opened either for reading or for writing.
pub enum TrrFileHandler {
    /// The file was opened for reading.
    Reader(TrrReader),
    /// The file was created for writing.
    Writer(TrrWriter),
}
/// Reads the payload of one frame whose header `h` has already been consumed.
///
/// The box matrix is always read when present. Virial and pressure blocks are
/// skipped. Coordinates, velocities and forces are read only when the
/// matching flag is `true`; a block that is present but not requested is
/// skipped and yields an empty vector, exactly like an absent block.
///
/// # Errors
/// * [`TrrHandlerError::Io`] on any read failure.
/// * [`TrrHandlerError::Pbc`] if the box matrix is rejected by `PeriodicBox::from_matrix`.
fn read_frame_data(
    r: &mut impl Read,
    h: &TrrHeader,
    coords: bool,
    velocities: bool,
    forces: bool,
) -> Result<State, TrrHandlerError> {
    /// Reads one real value of the frame's precision, narrowed to `f32`.
    fn read_real(src: &mut impl Read, wide: bool) -> io::Result<f32> {
        if wide {
            xdr_read_f64(src).map(|v| v as f32)
        } else {
            xdr_read_f32(src)
        }
    }

    let real_bytes: usize = if h.b_double { 8 } else { 4 };
    let natoms = h.natoms as usize;

    let pbox = match h.box_size {
        0 => None,
        _ => {
            let mut cells: Vec<f32> = Vec::new();
            for _ in 0..9 {
                cells.push(read_real(r, h.b_double)?);
            }
            let matrix = Matrix3f::from_iterator(cells);
            Some(PeriodicBox::from_matrix(matrix)?)
        }
    };

    // Virial and pressure matrices are never returned; step over them.
    for declared in [h.vir_size, h.pres_size] {
        if declared != 0 {
            xdr_skip(r, 9 * real_bytes)?;
        }
    }

    let mut coord_data: Vec<Pos> = Vec::new();
    if h.x_size != 0 {
        if coords {
            for _ in 0..natoms {
                let x = read_real(r, h.b_double)?;
                let y = read_real(r, h.b_double)?;
                let z = read_real(r, h.b_double)?;
                coord_data.push(Pos::new(x, y, z));
            }
        } else {
            xdr_skip(r, natoms * 3 * real_bytes)?;
        }
    }

    let vel_data: Vec<Vel> = match (h.v_size != 0, velocities) {
        (false, _) => Vec::new(),
        (true, true) => read_xvf(r, natoms, h.b_double)?,
        (true, false) => {
            xdr_skip(r, natoms * 3 * real_bytes)?;
            Vec::new()
        }
    };

    let force_data: Vec<Force> = match (h.f_size != 0, forces) {
        (false, _) => Vec::new(),
        (true, true) => read_xvf(r, natoms, h.b_double)?,
        (true, false) => {
            xdr_skip(r, natoms * 3 * real_bytes)?;
            Vec::new()
        }
    };

    Ok(State {
        coords: coord_data,
        velocities: vel_data,
        forces: force_data,
        pbox,
        time: h.time,
    })
}
/// Moves `r` forward past the payload of the frame described by `h` by
/// seeking relative to the current position.
fn skip_frame_data(r: &mut (impl Read + Seek), h: &TrrHeader) -> io::Result<()> {
    let payload = SeekFrom::Current(frame_data_size(h) as i64);
    r.seek(payload).map(|_| ())
}
impl FileFormatHandler for TrrFileHandler {
fn open(fname: impl AsRef<Path>) -> Result<Self, FileFormatError>
where
Self: Sized,
{
let fname = fname.as_ref();
let file = File::open(fname).map_err(TrrHandlerError::Io)?;
let mut reader = BufReader::new(file);
let h = read_trr_header(&mut reader).map_err(trr_to_ff_err)?;
let natoms = h.natoms as usize;
reader.seek(SeekFrom::Start(0)).map_err(TrrHandlerError::Io)?;
let _ = natoms; Ok(TrrFileHandler::Reader(TrrReader {
file: reader,
cur_frame: 0,
}))
}
fn create(fname: impl AsRef<Path>) -> Result<Self, FileFormatError>
where
Self: Sized,
{
let fname = fname.as_ref();
let file = File::create(fname).map_err(TrrHandlerError::Io)?;
Ok(TrrFileHandler::Writer(TrrWriter {
file: BufWriter::new(file),
cur_frame: 0,
natoms: None,
}))
}
fn read_state(&mut self) -> Result<State, FileFormatError> {
let TrrFileHandler::Reader(ref mut r) = self else {
return Err(FileFormatError::NotStateReadFormat);
};
let h = read_trr_header(&mut r.file).map_err(trr_to_ff_err)?;
let st = read_frame_data(&mut r.file, &h, true, true, true).map_err(trr_to_ff_err)?;
r.cur_frame += 1;
Ok(st)
}
fn read_state_pick(&mut self, coords: bool, velocities: bool, forces: bool) -> Result<State, FileFormatError> {
if !coords {
return Err(FileFormatError::NoCoords);
}
let TrrFileHandler::Reader(ref mut r) = self else {
return Err(FileFormatError::NotStateReadFormat);
};
let h = read_trr_header(&mut r.file).map_err(trr_to_ff_err)?;
let st = read_frame_data(&mut r.file, &h, coords, velocities, forces).map_err(trr_to_ff_err)?;
r.cur_frame += 1;
Ok(st)
}
fn write_state(&mut self, data: &dyn SaveState) -> Result<(), FileFormatError> {
self.write_state_pick(data, true, true, true)
}
fn write_state_pick(&mut self, data: &dyn SaveState, coords: bool, velocities: bool, forces: bool) -> Result<(), FileFormatError> {
let TrrFileHandler::Writer(ref mut w) = self else {
return Err(FileFormatError::NotStateWriteFormat);
};
let natoms = data.len();
if w.natoms.is_none() {
w.natoms = Some(natoms);
}
let has_box = data.get_box().is_some();
let box_size = if has_box { 9 * 4i32 } else { 0i32 };
let x_size = if coords { natoms as i32 * 3 * 4 } else { 0 };
let vel_it = if velocities { data.iter_vel_dyn() } else { Box::new(std::iter::empty()) };
let force_it = if forces { data.iter_force_dyn() } else { Box::new(std::iter::empty()) };
let v_size = if vel_it.len() > 0 { natoms as i32 * 3 * 4 } else { 0 };
let f_size = if force_it.len() > 0 { natoms as i32 * 3 * 4 } else { 0 };
let h = TrrHeader {
box_size,
vir_size: 0,
pres_size: 0,
x_size,
v_size,
f_size,
natoms: natoms as i32,
step: w.cur_frame as i32,
time: data.get_time(),
b_double: false,
};
write_trr_header(&mut w.file, &h)?;
if let Some(b) = data.get_box() {
for &v in b.get_matrix().as_slice() {
xdr_write_f32(&mut w.file, v)?;
}
}
if coords {
for p in data.iter_pos_dyn() {
xdr_write_f32(&mut w.file, p.x)?;
xdr_write_f32(&mut w.file, p.y)?;
xdr_write_f32(&mut w.file, p.z)?;
}
}
for v in vel_it {
xdr_write_f32(&mut w.file, v.x)?;
xdr_write_f32(&mut w.file, v.y)?;
xdr_write_f32(&mut w.file, v.z)?;
}
for f in force_it {
xdr_write_f32(&mut w.file, f.x)?;
xdr_write_f32(&mut w.file, f.y)?;
xdr_write_f32(&mut w.file, f.z)?;
}
w.cur_frame += 1;
Ok(())
}
fn seek_frame(&mut self, fr: usize) -> Result<(), FileFormatError> {
let TrrFileHandler::Reader(ref mut r) = self else {
return Err(FileFormatError::NotRandomAccessFormat);
};
if fr == r.cur_frame {
return Ok(());
}
if fr < r.cur_frame {
r.file
.seek(SeekFrom::Start(0))
.map_err(|e| TrrHandlerError::SeekFrame(fr, e))?;
r.cur_frame = 0;
}
let skip = fr - r.cur_frame;
for _ in 0..skip {
let h = read_trr_header(&mut r.file)
.map_err(|e| match e {
TrrHandlerError::Io(io_e) => TrrHandlerError::Io(io_e),
other => TrrHandlerError::Io(io::Error::new(io::ErrorKind::Other, other.to_string())),
})?;
skip_frame_data(&mut r.file, &h)
.map_err(|e| TrrHandlerError::SeekFrame(fr, e))?;
r.cur_frame += 1;
}
Ok(())
}
fn seek_last(&mut self) -> Result<(), FileFormatError> {
let TrrFileHandler::Reader(ref mut r) = self else {
return Err(FileFormatError::NotRandomAccessFormat);
};
r.file
.seek(SeekFrom::Start(0))
.map_err(TrrHandlerError::Io)?;
r.cur_frame = 0;
let mut last_pos: Option<u64> = None;
let mut last_frame: usize = 0;
let mut frame_idx: usize = 0;
loop {
let pos = r.file.stream_position().map_err(TrrHandlerError::Io)?;
match read_trr_header(&mut r.file) {
Ok(h) => {
last_pos = Some(pos);
last_frame = frame_idx;
match skip_frame_data(&mut r.file, &h) {
Ok(()) => {}
Err(_) => break,
}
frame_idx += 1;
}
Err(TrrHandlerError::Io(e)) if e.kind() == io::ErrorKind::UnexpectedEof => break,
Err(e) => return Err(FileFormatError::from(e)),
}
}
match last_pos {
Some(pos) => {
r.file
.seek(SeekFrom::Start(pos))
.map_err(TrrHandlerError::Io)?;
r.cur_frame = last_frame;
Ok(())
}
None => Err(FileFormatError::Eof),
}
}
fn seek_time(&mut self, t: f32) -> Result<(), FileFormatError> {
let TrrFileHandler::Reader(ref mut r) = self else {
return Err(FileFormatError::NotRandomAccessFormat);
};
r.file
.seek(SeekFrom::Start(0))
.map_err(TrrHandlerError::Io)?;
let mut frame_idx: usize = 0;
loop {
let pos = r.file.stream_position().map_err(TrrHandlerError::Io)?;
match read_trr_header(&mut r.file) {
Ok(h) => {
if h.time >= t {
r.file
.seek(SeekFrom::Start(pos))
.map_err(|e| TrrHandlerError::SeekTime(t, e))?;
r.cur_frame = frame_idx;
return Ok(());
}
skip_frame_data(&mut r.file, &h)
.map_err(|e| TrrHandlerError::SeekTime(t, e))?;
frame_idx += 1;
}
Err(TrrHandlerError::Io(e)) if e.kind() == io::ErrorKind::UnexpectedEof => {
return Err(FileFormatError::Eof);
}
Err(e) => return Err(FileFormatError::from(e)),
}
}
}
}
/// Converts a [`TrrHandlerError`] into a `FileFormatError`.
///
/// An I/O error of kind `UnexpectedEof` becomes `FileFormatError::Eof`; every
/// other error goes through the regular `From` conversion.
fn trr_to_ff_err(e: TrrHandlerError) -> FileFormatError {
    let hit_eof = matches!(
        &e,
        TrrHandlerError::Io(io_err) if io_err.kind() == io::ErrorKind::UnexpectedEof
    );
    if hit_eof {
        FileFormatError::Eof
    } else {
        FileFormatError::from(e)
    }
}