use super::topology_state::TopologyPy;
use crate::atom::AtomView;
use crate::utils::map_pyarray_to_pos;
use crate::{atom::AtomPy, topology_state::StatePy};
use molar::{AtomLike, AtomMutProvider, State, Topology};
use numpy::{PyArray1, PyArrayLike1, PyArrayMethods, PyUntypedArrayMethods};
use pyo3::exceptions::PyTypeError;
use pyo3::prelude::*;
/// Python-facing handle to one particle: an atom of a topology paired with its
/// coordinates in a state. Both are addressed by the same index, `id`.
#[pyclass(name = "Particle", frozen)]
pub(crate) struct ParticlePy {
    /// Topology that holds this particle's atom record.
    pub(crate) top: Py<TopologyPy>,
    /// State that holds this particle's coordinates.
    pub(crate) st: Py<StatePy>,
    /// Index of the particle in both the topology atoms and the state coords
    /// (exposed to Python read-only).
    #[pyo3(get)]
    pub(crate) id: usize,
}
impl ParticlePy {
pub(crate) fn top(&self) -> &Topology {
self.top.get().inner()
}
pub(crate) fn top_mut(&self) -> &mut Topology {
self.top.get().inner_mut()
}
pub(crate) fn st(&self) -> &State {
self.st.get().inner()
}
pub(crate) fn st_mut(&self) -> &mut State {
self.st.get().inner_mut()
}
}
#[pymethods]
impl ParticlePy {
    // Every accessor below indexes the state's `coords` and the topology's `atoms`
    // with `get_unchecked*`, i.e. without bounds checks.
    // SAFETY (assumed, not enforced in this block): `self.id` is a valid index into
    // both containers for the whole lifetime of this Particle.
    // NOTE(review): confirm that whoever constructs `ParticlePy` validates `id` and
    // that neither container can shrink while a Particle is alive; otherwise these
    // accesses are undefined behavior.

    /// Position of the particle as a float32 NumPy array.
    #[getter(pos)]
    fn get_pos<'py>(slf: &'py Bound<'py, Self>) -> Bound<'py, PyArray1<f32>> {
        let s = slf.get();
        // SAFETY: relies on the contract of `crate::utils::map_pyarray_to_pos`;
        // presumably the returned array maps the state's memory for particle `s.id`.
        // NOTE(review): confirm.
        unsafe { map_pyarray_to_pos(s.st.bind(slf.py()), s.id) }
    }

    /// Set the particle position from a 1-D float32 array-like with exactly 3 elements.
    ///
    /// Raises `TypeError` if `pos` does not have 3 elements.
    #[setter(pos)]
    fn set_pos(&self, pos: PyArrayLike1<f32>) -> PyResult<()> {
        if pos.len() != 3 {
            return Err(PyTypeError::new_err("pos must have 3 elements"));
        }
        // Read through an ndarray view so that strided (non-contiguous) inputs are
        // handled correctly. Copy into locals before touching the state, because
        // `pos` may alias this particle's own coordinates (e.g. `p.pos = p.pos`).
        let src = pos.as_array();
        let (x, y, z) = (src[0], src[1], src[2]);
        // Write into *this* particle's coordinates (index `self.id`), not coords[0].
        let dst = unsafe { self.st_mut().coords.get_unchecked_mut(self.id) };
        dst.x = x;
        dst.y = y;
        dst.z = z;
        Ok(())
    }

    /// X coordinate of the particle.
    #[getter(x)]
    fn get_x(&self) -> f32 {
        unsafe { self.st().coords.get_unchecked(self.id).x }
    }

    /// Set the X coordinate of the particle.
    #[setter(x)]
    fn set_x(&self, value: f32) {
        unsafe { self.st_mut().coords.get_unchecked_mut(self.id).x = value }
    }

    /// Y coordinate of the particle.
    #[getter(y)]
    fn get_y(&self) -> f32 {
        unsafe { self.st().coords.get_unchecked(self.id).y }
    }

    /// Set the Y coordinate of the particle.
    #[setter(y)]
    fn set_y(&self, value: f32) {
        unsafe { self.st_mut().coords.get_unchecked_mut(self.id).y = value }
    }

    /// Z coordinate of the particle.
    #[getter(z)]
    fn get_z(&self) -> f32 {
        unsafe { self.st().coords.get_unchecked(self.id).z }
    }

    /// Set the Z coordinate of the particle.
    #[setter(z)]
    fn set_z(&self, value: f32) {
        unsafe { self.st_mut().coords.get_unchecked_mut(self.id).z = value }
    }

    /// Text representation: id, atom name, residue name, residue id and the
    /// position printed with 3 decimals.
    fn __repr__(&self) -> String {
        let pos = unsafe { self.st().coords.get_unchecked(self.id) };
        let a = unsafe { self.top().atoms.get_unchecked(self.id) };
        format!(
            "Particle(id={}, name='{}', resname='{}', resid={}, pos=[{:.3}, {:.3}, {:.3}])",
            self.id, a.name, a.resname, a.resid, pos.x, pos.y, pos.z
        )
    }

    /// View of this particle's atom. The returned `AtomView` references the same
    /// topology and index rather than holding a copy.
    #[getter(atom)]
    fn get_atom(slf: &Bound<'_, Self>) -> AtomView {
        let s = slf.get();
        AtomView {
            top: s.top.clone_ref(slf.py()),
            index: s.id,
        }
    }

    /// Replace this particle's atom with a copy of `arg`, which must be an `Atom`
    /// or an `AtomView`.
    ///
    /// Raises `TypeError` for any other argument type. Errors from resolving an
    /// `AtomView` are propagated.
    #[setter(atom)]
    fn set_atom(&self, arg: &Bound<'_, PyAny>) -> PyResult<()> {
        let at = if let Ok(at) = arg.cast::<AtomPy>() {
            at.borrow().0.clone()
        } else if let Ok(v) = arg.cast::<AtomView>() {
            v.borrow().atom()?.clone()
        } else {
            let ty_name = arg.get_type().name()?.to_string();
            return Err(PyTypeError::new_err(format!(
                "Invalid argument type {ty_name} in set_atom()"
            )));
        };
        unsafe { *self.top_mut().get_atom_mut_unchecked(self.id) = at };
        Ok(())
    }

    /// Atom name.
    #[getter(name)]
    fn get_name(&self) -> String {
        let atom = unsafe { self.top().atoms.get_unchecked(self.id) };
        atom.name.as_str().to_owned()
    }

    /// Set the atom name.
    #[setter(name)]
    fn set_name(&self, value: &str) {
        unsafe { self.top_mut().atoms.get_unchecked_mut(self.id).set_name(value) }
    }

    /// Residue name.
    #[getter(resname)]
    fn get_resname(&self) -> String {
        let atom = unsafe { self.top().atoms.get_unchecked(self.id) };
        atom.resname.as_str().to_owned()
    }

    /// Set the residue name.
    #[setter(resname)]
    fn set_resname(&self, value: &str) {
        unsafe { self.top_mut().atoms.get_unchecked_mut(self.id).set_resname(value) }
    }

    /// Residue id.
    #[getter(resid)]
    fn get_resid(&self) -> i32 {
        unsafe { self.top().atoms.get_unchecked(self.id).resid }
    }

    /// Set the residue id.
    #[setter(resid)]
    fn set_resid(&self, value: i32) {
        unsafe { self.top_mut().atoms.get_unchecked_mut(self.id).resid = value }
    }

    /// Residue index.
    #[getter(resindex)]
    fn get_resindex(&self) -> usize {
        unsafe { self.top().atoms.get_unchecked(self.id).resindex }
    }

    /// Set the residue index.
    #[setter(resindex)]
    fn set_resindex(&self, value: usize) {
        unsafe { self.top_mut().atoms.get_unchecked_mut(self.id).resindex = value }
    }

    /// Atomic number.
    #[getter(atomic_number)]
    fn get_atomic_number(&self) -> u8 {
        unsafe { self.top().atoms.get_unchecked(self.id).atomic_number }
    }

    /// Set the atomic number.
    #[setter(atomic_number)]
    fn set_atomic_number(&self, value: u8) {
        unsafe { self.top_mut().atoms.get_unchecked_mut(self.id).atomic_number = value }
    }

    /// Atom mass.
    #[getter(mass)]
    fn get_mass(&self) -> f32 {
        unsafe { self.top().atoms.get_unchecked(self.id).mass }
    }

    /// Set the atom mass.
    #[setter(mass)]
    fn set_mass(&self, value: f32) {
        unsafe { self.top_mut().atoms.get_unchecked_mut(self.id).mass = value }
    }

    /// Atom charge.
    #[getter(charge)]
    fn get_charge(&self) -> f32 {
        unsafe { self.top().atoms.get_unchecked(self.id).charge }
    }

    /// Set the atom charge.
    #[setter(charge)]
    fn set_charge(&self, value: f32) {
        unsafe { self.top_mut().atoms.get_unchecked_mut(self.id).charge = value }
    }

    /// Atom type name.
    #[getter(type_name)]
    fn get_type_name(&self) -> String {
        let atom = unsafe { self.top().atoms.get_unchecked(self.id) };
        atom.type_name.as_str().to_owned()
    }

    /// Set the atom type name.
    #[setter(type_name)]
    fn set_type_name(&self, value: &str) {
        unsafe { self.top_mut().atoms.get_unchecked_mut(self.id).set_type_name(value) }
    }

    /// Atom type id.
    #[getter(type_id)]
    fn get_type_id(&self) -> u32 {
        unsafe { self.top().atoms.get_unchecked(self.id).type_id }
    }

    /// Set the atom type id.
    #[setter(type_id)]
    fn set_type_id(&self, value: u32) {
        unsafe { self.top_mut().atoms.get_unchecked_mut(self.id).type_id = value }
    }

    /// Chain identifier.
    #[getter(chain)]
    fn get_chain(&self) -> char {
        unsafe { self.top().atoms.get_unchecked(self.id).chain }
    }

    /// Set the chain identifier.
    #[setter(chain)]
    fn set_chain(&self, value: char) {
        unsafe { self.top_mut().atoms.get_unchecked_mut(self.id).chain = value }
    }

    /// B-factor.
    #[getter(bfactor)]
    fn get_bfactor(&self) -> f32 {
        unsafe { self.top().atoms.get_unchecked(self.id).bfactor }
    }

    /// Set the B-factor.
    #[setter(bfactor)]
    fn set_bfactor(&self, value: f32) {
        unsafe { self.top_mut().atoms.get_unchecked_mut(self.id).bfactor = value }
    }

    /// Occupancy.
    #[getter(occupancy)]
    fn get_occupancy(&self) -> f32 {
        unsafe { self.top().atoms.get_unchecked(self.id).occupancy }
    }

    /// Set the occupancy.
    #[setter(occupancy)]
    fn set_occupancy(&self, value: f32) {
        unsafe { self.top_mut().atoms.get_unchecked_mut(self.id).occupancy = value }
    }
}