use std::sync::atomic::{AtomicUsize, Ordering};
use molar::prelude::*;
use numpy::{
nalgebra::{self},
PyArrayMethods,
};
use pyo3::{
exceptions::{PyNotImplementedError, PyTypeError, PyValueError},
prelude::*,
IntoPyObjectExt,
};
mod utils;
use utils::*;
mod atom;
use atom::AtomPy;
mod particle;
use particle::ParticlePy;
mod periodic_box;
use periodic_box::PeriodicBoxPy;
mod file_handler;
use file_handler::{FileHandlerPy, FileStatsPy};
mod system;
use system::SystemPy;
mod selection;
use selection::SelPy;
mod topology_state;
use topology_state::*;
use crate::{
selection::{SelAtomIterator, SelPosIterator, TmpSel},
system::{SysAtomIterator, SysParticleIterator, SysPosIterator},
};
/// Python wrapper around a molar `Sasa` result, exposing its `areas` and
/// `volumes` arrays together with their totals.
#[pyclass(unsendable, name = "Sasa")]
struct SasaPy(Sasa);

#[pymethods]
impl SasaPy {
    /// The `areas()` slice of the wrapped result.
    #[getter]
    fn areas(&self) -> &[f32] {
        let Self(sasa) = self;
        sasa.areas()
    }

    /// The `volumes()` slice of the wrapped result.
    #[getter]
    fn volumes(&self) -> &[f32] {
        let Self(sasa) = self;
        sasa.volumes()
    }

    /// The `total_area()` of the wrapped result.
    #[getter]
    fn total_area(&self) -> f32 {
        let Self(sasa) = self;
        sasa.total_area()
    }

    /// The `total_volume()` of the wrapped result.
    #[getter]
    fn total_volume(&self) -> f32 {
        let Self(sasa) = self;
        sasa.total_volume()
    }

    /// Short summary: number of `areas` entries plus both totals (3 decimals).
    fn __repr__(&self) -> String {
        let sasa = &self.0;
        let count = sasa.areas().len();
        let area = sasa.total_area();
        let volume = sasa.total_volume();
        format!(
            "Sasa(n={}, total_area={:.3}, total_volume={:.3})",
            count, area, volume
        )
    }
}
/// Rotation + translation (`nalgebra::IsometryMatrix3<f32>`) produced by the
/// `fit_transform*` functions.
#[pyclass]
struct IsometryTransform(nalgebra::IsometryMatrix3<f32>);

#[pymethods]
impl IsometryTransform {
    /// Shows only the translation component, rounded to three decimals.
    fn __repr__(&self) -> String {
        let Self(iso) = self;
        let shift = &iso.translation.vector;
        let (x, y, z) = (shift[0], shift[1], shift[2]);
        format!(
            "IsometryTransform(trans=[{:.3}, {:.3}, {:.3}])",
            x, y, z
        )
    }
}
#[pyfunction(name = "fit_transform")]
fn fit_transform_py(sel1: &SelPy, sel2: &SelPy) -> PyResult<IsometryTransform> {
let tr = molar::prelude::fit_transform(sel1, sel2).map_err(to_py_runtime_err)?;
Ok(IsometryTransform(tr))
}
#[pyfunction(name = "fit_transform_matching")]
fn fit_transform_matching_py(sel1: &SelPy, sel2: &SelPy) -> PyResult<IsometryTransform> {
let (ind1, ind2) = get_matching_atoms_by_name(sel1, sel2);
let sub1 = ind1
.into_sel_index(sel1, Some(&sel1.index()))
.map_err(to_py_runtime_err)?;
let sub2 = ind2
.into_sel_index(sel2, Some(&sel2.index()))
.map_err(to_py_runtime_err)?;
let sub_sel1 = TmpSel {
top: sel1.r_top(),
st: sel1.r_st(),
index: &sub1,
};
let sub_sel2 = TmpSel {
top: sel2.r_top(),
st: sel2.r_st(),
index: &sub2,
};
let tr = fit_transform(&sub_sel1, &sub_sel2).map_err(to_py_runtime_err)?;
Ok(IsometryTransform(tr))
}
/// RMSD between two selections, computed by molar's `rmsd`.
///
/// Errors from the computation are converted with `to_py_runtime_err`.
///
/// NOTE(review): with a bare `#[pyfunction]` pyo3 exports this under its Rust
/// name, `rmsd_py`. Sibling functions (`rmsd_mw`, `fit_transform`) drop the
/// `_py` suffix via `name = ...`. The name is kept as-is here so existing
/// Python callers do not break. Consider adding an `rmsd` alias.
#[pyfunction]
fn rmsd_py(sel1: &SelPy, sel2: &SelPy) -> PyResult<f32> {
    rmsd(sel1, sel2).map_err(to_py_runtime_err)
}
/// Python `rmsd_mw(sel1, sel2)`: RMSD between two selections computed by
/// molar's `rmsd_mw` (mass-weighted variant, per its name — confirm in molar).
///
/// Errors from the computation are converted with `to_py_runtime_err`.
#[pyfunction(name = "rmsd_mw")]
fn rmsd_mw_py(sel1: &SelPy, sel2: &SelPy) -> PyResult<f32> {
    rmsd_mw(sel1, sel2).map_err(to_py_runtime_err)
}
/// Python iterator yielding the particles of a selection one by one.
///
/// The class is `frozen` (no `&mut self` access), so the cursor is kept in an
/// `AtomicUsize`.
#[pyclass(frozen)]
struct ParticleIterator {
    // Selection being iterated.
    sel: Py<SelPy>,
    // Position of the next particle to hand out.
    cur: AtomicUsize,
}

#[pymethods]
impl ParticleIterator {
    /// An iterator is its own iterable.
    fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
        slf
    }

    /// Returns the particle at the current cursor and advances the cursor.
    ///
    /// Iteration ends (`None`) as soon as `SelPy::__getitem__` returns an
    /// error for the claimed position.
    fn __next__(slf: &Bound<'_, Self>) -> Option<ParticlePy> {
        let this = slf.get();
        // Claim the index and advance in one atomic read-modify-write. A
        // separate load + fetch_add would let two concurrent callers on this
        // shared (frozen) object receive the same position. Relaxed suffices:
        // RMW operations on a single atomic are totally ordered, and no other
        // memory is synchronized through this counter.
        let idx = this.cur.fetch_add(1, Ordering::Relaxed);
        SelPy::__getitem__(this.sel.get(), idx as isize).ok()
    }
}
#[pyfunction]
#[pyo3(signature = (cutoff,data1,data2=None,dims=None))]
#[pyo3(text_signature = "(cutoff, data1, data2=None, dims=None)")]
fn distance_search<'py>(
py: Python<'py>,
cutoff: &Bound<'py, PyAny>,
data1: &Bound<'py, SelPy>,
data2: Option<&Bound<'py, SelPy>>,
dims: Option<[bool; 3]>,
) -> PyResult<Bound<'py, PyAny>> {
let mut res: Vec<(usize, usize, f32)>;
let dims = dims.unwrap_or([false, false, false]);
let pbc_dims = PbcDims::new(dims[0], dims[1], dims[2]);
let sel1 = data1.borrow();
if let Ok(d) = cutoff.extract::<f32>() {
if let Some(d2) = data2 {
let sel2 = d2.borrow();
if pbc_dims.any() {
res = distance_search_double_pbc(
d,
sel1.iter_pos(),
sel2.iter_pos(),
sel1.iter_index(),
sel2.iter_index(),
&sel1.require_box().unwrap(),
pbc_dims,
);
} else {
res = distance_search_double(
d,
&sel1 as &SelPy,
&sel2 as &SelPy,
sel1.iter_index(),
sel2.iter_index(),
);
}
} else {
if pbc_dims.any() {
res = distance_search_single_pbc(
d,
sel1.iter_pos(),
sel1.iter_index(),
&sel1.require_box().unwrap(),
pbc_dims,
);
} else {
res = distance_search_single(d, &sel1 as &SelPy, sel1.iter_index());
}
}
} else if let Ok(s) = cutoff.extract::<String>() {
if s != "vdw" {
return Err(PyTypeError::new_err(format!("Unknown cutoff type {s}")));
}
let vdw1: Vec<f32> = sel1.iter_atoms().map(|a| a.vdw()).collect();
if sel1.len() != vdw1.len() {
return Err(PyValueError::new_err(format!(
"Size mismatch 1: {} {}",
sel1.len(),
vdw1.len()
)));
}
if let Some(d2) = data2 {
let sel2 = d2.borrow();
let vdw2: Vec<f32> = sel2.iter_atoms().map(|a| a.vdw()).collect();
if sel2.len() != vdw2.len() {
return Err(PyValueError::new_err(format!(
"Size mismatch 2: {} {}",
sel2.len(),
vdw2.len()
)));
}
if pbc_dims.any() {
res = distance_search_double_vdw_pbc(
sel1.iter_pos(),
sel2.iter_pos(),
&vdw1,
&vdw2,
&sel1.require_box().unwrap(),
pbc_dims,
);
} else {
res = distance_search_double_vdw(&sel1 as &SelPy, &sel2 as &SelPy, &vdw1, &vdw2);
}
unsafe {
for el in &mut res {
el.0 = sel1.get_index_unchecked(el.0);
el.1 = sel2.get_index_unchecked(el.1);
}
}
} else {
return Err(PyNotImplementedError::new_err(
"VdW distance search is not yet supported for single selection",
));
}
} else {
return Err(PyTypeError::new_err("cutoff must be a float or 'vdw'"));
};
let n = res.len();
let mut flat_pairs = Vec::with_capacity(n * 2);
let mut dists = Vec::with_capacity(n);
for (i, j, d) in res {
flat_pairs.push(i);
flat_pairs.push(j);
dists.push(d);
}
let pairs_arr = numpy::PyArray1::from_vec(py, flat_pairs).reshape([n, 2])?;
let dist_arr = numpy::PyArray1::from_vec(py, dists);
Ok((pairs_arr, dist_arr).into_bound_py_any(py)?)
}
/// Python `NdxFile`: wraps molar's `NdxFile`, whose named index groups can be
/// turned into selections.
#[pyclass(name = "NdxFile")]
struct NdxFilePy(NdxFile);

#[pymethods]
impl NdxFilePy {
    /// Loads the file at `fname`. Load errors are converted with
    /// `to_py_value_err`.
    #[new]
    fn new(fname: &str) -> PyResult<Self> {
        NdxFile::new(fname)
            .map_err(to_py_value_err)
            .map(NdxFilePy)
    }

    /// Builds a selection over `sys` (sharing its topology and state handles)
    /// from the indices of group `gr_name`. Unknown groups are reported via
    /// `to_py_value_err`.
    fn get_group_as_sel(&self, gr_name: &str, sys: &SystemPy) -> PyResult<SelPy> {
        Python::attach(|py| {
            let top = sys.py_top().clone_ref(py);
            let st = sys.py_st().clone_ref(py);
            let group = self.0.get_group(gr_name).map_err(to_py_value_err)?;
            Ok(SelPy::new(top, st, group.to_owned()))
        })
    }
}
/// Python `greeting()`: forwards to `molar::greeting`, identifying the caller
/// as `"molar_python"`.
#[pyfunction]
fn greeting() {
    let caller = "molar_python";
    molar::greeting(caller);
}
#[pymodule(name = "molar")]
fn molar_python(m: &Bound<'_, PyModule>) -> PyResult<()> {
pyo3_log::init();
m.add_class::<AtomPy>()?;
m.add_class::<ParticlePy>()?;
m.add_class::<TopologyPy>()?;
m.add_class::<StatePy>()?;
m.add_class::<PeriodicBoxPy>()?;
m.add_class::<FileHandlerPy>()?;
m.add_class::<FileStatsPy>()?;
m.add_class::<SystemPy>()?;
m.add_class::<SelPy>()?;
m.add_class::<SasaPy>()?;
m.add_class::<NdxFilePy>()?;
m.add_class::<SysPosIterator>()?;
m.add_class::<SysAtomIterator>()?;
m.add_class::<SysParticleIterator>()?;
m.add_class::<SelPosIterator>()?;
m.add_class::<SelAtomIterator>()?;
m.add_function(wrap_pyfunction!(greeting, m)?)?;
m.add_function(wrap_pyfunction!(fit_transform_py, m)?)?;
m.add_function(wrap_pyfunction!(fit_transform_matching_py, m)?)?;
m.add_function(wrap_pyfunction!(rmsd_py, m)?)?;
m.add_function(wrap_pyfunction!(rmsd_mw_py, m)?)?;
m.add_function(wrap_pyfunction!(distance_search, m)?)?;
Ok(())
}