use std::{ffi::CString, path::Path, sync::Arc};
use molar_gromacs::{TprAtom, TprBond, TprHandle, TprMolecule, TprPlugin};
use nalgebra::Matrix3;
use thiserror::Error;
use crate::atom::{AtomStr, ATOM_NAME_EXPECT, ATOM_RESNAME_EXPECT, ATOM_TYPE_NAME_EXPECT};
use crate::prelude::*;
/// Reader for Gromacs `.tpr` files, backed by the Gromacs plugin ([`TprPlugin`]).
///
/// A single call to `read` returns the topology and state; later calls report EOF.
pub struct TprFileHandler {
    /// Shared plugin instance whose function table (`fns`) performs all file access.
    plugin: Arc<TprPlugin>,
    /// Opaque file handle returned by the plugin's `tpr_open`; closed in `Drop`.
    handle: *mut TprHandle,
    /// Set by the first `read`, after which `read` returns `FileFormatError::Eof`.
    already_read: bool,
}
// SAFETY: `handle` is owned exclusively by this value. In this file it is only used
// through `&mut self` (in `read`) and in `Drop`, and the type is intentionally not `Sync`.
// NOTE(review): this also assumes the plugin's `TprHandle` may be used and closed on a
// thread other than the one that opened it, and that `TprPlugin` is safe to share via
// `Arc` across threads — confirm against the plugin implementation.
unsafe impl Send for TprFileHandler {}
/// Errors specific to reading TPR files through the Gromacs plugin.
#[derive(Debug, Error)]
pub enum TprHandlerError {
    /// The Gromacs plugin could not be obtained; carries the loader's error message.
    #[error("Gromacs plugin not found (is MOLAR_GROMACS_PLUGIN set correctly?): {0}")]
    GromacsNotFound(String),
    /// Opening the file failed; carries a human-readable reason
    /// (for example the plugin's last error message).
    #[error("failed to open TPR file: {0}")]
    OpenFailed(String),
    /// The file path contained an interior NUL byte and cannot be passed as a C string.
    #[error("unexpected null character in path")]
    CStringNull(#[from] std::ffi::NulError),
    /// The box matrix read from the file was rejected by `PeriodicBox::from_matrix`.
    #[error("invalid periodic box")]
    Pbc(#[from] PeriodicBoxError),
}
impl Drop for TprFileHandler {
fn drop(&mut self) {
if !self.handle.is_null() {
unsafe { (self.plugin.fns.tpr_close)(self.handle) };
self.handle = std::ptr::null_mut();
}
}
}
impl FileFormatHandler for TprFileHandler {
    /// Opens a Gromacs `.tpr` file through the Gromacs plugin.
    ///
    /// # Errors
    /// - [`TprHandlerError::GromacsNotFound`] if the plugin cannot be obtained.
    /// - [`TprHandlerError::OpenFailed`] if `fname` is not valid UTF-8, or if the plugin
    ///   fails to open the file (the plugin's last error message is included).
    /// - [`TprHandlerError::CStringNull`] if the path contains an interior NUL byte.
    fn open(fname: impl AsRef<Path>) -> Result<Self, FileFormatError>
    where
        Self: Sized,
    {
        let plugin = TprPlugin::get_cached()
            .map_err(|e| TprHandlerError::GromacsNotFound(e.to_string()))?;

        let path = fname.as_ref();
        // `Path::to_str` fails for non-UTF-8 paths; report that as an error instead of
        // panicking inside a library call.
        let path_str = path.to_str().ok_or_else(|| {
            TprHandlerError::OpenFailed(format!("path is not valid UTF-8: {}", path.display()))
        })?;
        let c_path = CString::new(path_str).map_err(TprHandlerError::CStringNull)?;

        // SAFETY: `c_path` is a valid NUL-terminated string that outlives the call.
        let handle = unsafe { (plugin.fns.tpr_open)(c_path.as_ptr()) };
        if handle.is_null() {
            let msg = plugin.last_error();
            return Err(TprHandlerError::OpenFailed(msg).into());
        }
        Ok(TprFileHandler { plugin, handle, already_read: false })
    }

    /// Reads the topology and coordinates from the opened file.
    ///
    /// This handler yields a single `(Topology, State)` pair. The first call returns
    /// it, and every later call returns [`FileFormatError::Eof`].
    ///
    /// # Errors
    /// - [`FileFormatError::Eof`] on any call after the first.
    /// - [`TprHandlerError::Pbc`] if the box read from the file is rejected by
    ///   `PeriodicBox::from_matrix`.
    fn read(&mut self) -> Result<(Topology, State), FileFormatError> {
        if self.already_read {
            return Err(FileFormatError::Eof);
        }
        self.already_read = true;

        let fns = &self.plugin.fns;
        let h = self.handle;

        // SAFETY: `h` is the non-null handle returned by `tpr_open` (checked in `open`)
        // and is only closed in `Drop`.
        let natoms = unsafe { (fns.tpr_natoms)(h) };
        let nbonds = unsafe { (fns.tpr_nbonds)(h) };
        let nmols = unsafe { (fns.tpr_nmolecules)(h) };

        let mut atoms_buf: Vec<TprAtom> = Vec::with_capacity(natoms);
        let mut bonds_buf: Vec<TprBond> = Vec::with_capacity(nbonds);
        let mut mols_buf: Vec<TprMolecule> = Vec::with_capacity(nmols);
        let mut coords_buf: Vec<f32> = Vec::with_capacity(natoms * 3);
        let mut box_uninit = std::mem::MaybeUninit::<[f32; 9]>::uninit();

        // SAFETY: each buffer has capacity for exactly the element count the plugin
        // reported above, and `h` is valid (see above). Lengths are set only *after*
        // the fill calls, because `Vec::set_len` requires every element up to the new
        // length to already be initialized. Setting them first would briefly expose
        // uninitialized memory.
        // NOTE(review): this relies on each `tpr_fill_*` writing every element of its
        // buffer (and all 9 box values) — confirm against the plugin.
        let box_buf = unsafe {
            (fns.tpr_fill_atoms)(h, atoms_buf.as_mut_ptr());
            (fns.tpr_fill_bonds)(h, bonds_buf.as_mut_ptr());
            (fns.tpr_fill_molecules)(h, mols_buf.as_mut_ptr());
            (fns.tpr_fill_coords)(h, coords_buf.as_mut_ptr());
            (fns.tpr_fill_box)(h, box_uninit.as_mut_ptr() as *mut f32);
            atoms_buf.set_len(natoms);
            bonds_buf.set_len(nbonds);
            mols_buf.set_len(nmols);
            coords_buf.set_len(natoms * 3);
            box_uninit.assume_init()
        };

        let mut top = Topology::default();
        top.atoms.reserve(natoms);
        for a in &atoms_buf {
            // A zero byte means "no chain" and is stored as a blank chain identifier.
            let chain = if a.chain == 0 { ' ' } else { a.chain as char };
            top.atoms.push(Atom {
                name: atom_str_from_buf(&a.name, ATOM_NAME_EXPECT),
                resname: atom_str_from_buf(&a.resname, ATOM_RESNAME_EXPECT),
                type_name: atom_str_from_buf(&a.type_name, ATOM_TYPE_NAME_EXPECT),
                chain,
                // NOTE(review): the plugin's residue *index* is stored as `resid`;
                // confirm whether a separate residue number should be used here.
                resid: a.resind as i32,
                type_id: a.type_id,
                // NOTE(review): `as u8` wraps/truncates out-of-range values; confirm
                // the plugin only reports atomic numbers in 0..=255.
                atomic_number: a.atomic_number as u8,
                charge: a.charge,
                mass: a.mass,
                occupancy: a.occupancy,
                bfactor: a.bfactor,
                ..Default::default()
            });
        }
        for b in &bonds_buf {
            top.bonds.push([b.atom1 as usize, b.atom2 as usize]);
        }
        for m in &mols_buf {
            top.molecules.push([m.start as usize, m.end as usize]);
        }
        top.assign_resindex();

        let mut st = State::default();
        st.coords.resize(natoms, Default::default());
        // The plugin delivers coordinates as a flat x,y,z,x,y,z,... array.
        for (pos, xyz) in st.coords.iter_mut().zip(coords_buf.chunks_exact(3)) {
            pos.coords.copy_from_slice(xyz);
        }

        // `from_column_slice` places each consecutive triple of `box_buf` into one
        // matrix column.
        let m = Matrix3::from_column_slice(&box_buf);
        st.pbox = Some(PeriodicBox::from_matrix(m).map_err(TprHandlerError::Pbc)?);
        Ok((top, st))
    }
}
/// Converts a fixed-width, NUL-padded string field from the plugin into an [`AtomStr`].
///
/// Only the bytes before the first NUL are used. A field without any NUL uses all
/// 8 bytes.
///
/// # Panics
/// Panics if those bytes are not valid UTF-8. Also panics, with `expect` as the
/// message, if `AtomStr::try_from_str` rejects the resulting string.
fn atom_str_from_buf(buf: &[u8; 8], expect: &str) -> AtomStr {
    // `split` always yields a first (possibly empty) chunk, so the fallback is never used.
    let bytes = buf.split(|&byte| byte == 0).next().unwrap_or(&buf[..]);
    let text = std::str::from_utf8(bytes).expect("Gromacs atom strings are ASCII");
    AtomStr::try_from_str(text).expect(expect)
}
#[cfg(test)]
mod tests {
    use crate::io::TprFileHandler;
    use crate::prelude::*;

    /// Returns `true` if any error in `err`'s `source()` chain reports that the
    /// Gromacs plugin could not be found.
    fn is_plugin_missing(err: &dyn std::error::Error) -> bool {
        std::iter::successors(Some(err), |e| e.source())
            .any(|e| e.to_string().contains("plugin not found"))
    }

    /// Reads `tests/topol.tpr` and checks basic consistency of the result.
    /// The test is skipped (not failed) when the Gromacs plugin is unavailable.
    #[test]
    fn test_tpr() {
        let mut h = match TprFileHandler::open("tests/topol.tpr") {
            Ok(h) => h,
            Err(e) if is_plugin_missing(&e) => {
                eprintln!("Skipping test_tpr: Gromacs plugin not available");
                return;
            }
            Err(e) => panic!("unexpected error: {e}"),
        };
        let (top, st) = h.read().unwrap();
        println!("natoms: {:?}", top.len());
        println!("nbonds: {:?}", BondProvider::num_bonds(&top));
        println!("nmolecules: {:?}", MolProvider::num_molecules(&top));
        println!("state sz: {:?}", st.len());

        // `read` produces exactly one coordinate per topology atom.
        assert_eq!(top.len(), st.len());
        // The handler yields a single frame; a second read must report end of file.
        assert!(matches!(h.read(), Err(FileFormatError::Eof)));
    }
}