use crate::prelude::*;
use crate::sasa::Sasa;
use itertools::izip;
use nalgebra::{DVector, IsometryMatrix3, Rotation3, SymmetricEigen, Translation3};
use num_traits::Bounded;
use powersasa::SasaError;
use serde::Deserialize;
use std::f32::consts::PI;
use std::iter::zip;
use thiserror::Error;
/// Purely geometric measurements over a set of positions.
pub trait MeasurePos: PosIterProvider + LenProvider {
    /// Component-wise bounding box of all positions, returned as `(lower, upper)`.
    ///
    /// With no positions the result is `(Pos::max_value(), Pos::min_value())`.
    fn min_max(&self) -> (Pos, Pos) {
        self.iter_pos().fold(
            (Pos::max_value(), Pos::min_value()),
            |(mut lo, mut hi), p| {
                for dim in 0..3 {
                    if p[dim] < lo[dim] {
                        lo[dim] = p[dim]
                    }
                    if p[dim] > hi[dim] {
                        hi[dim] = p[dim]
                    }
                }
                (lo, hi)
            },
        )
    }

    /// Unweighted mean of all positions (divides by `self.len()`).
    fn center_of_geometry(&self) -> Pos {
        let sum = self
            .iter_pos()
            .fold(Vector3f::zeros(), |acc, p| acc + p.coords);
        Pos::from(sum / self.len() as f32)
    }

    /// Root-mean-square deviation between `self` and `other`; see the free function `rmsd`.
    fn rmsd<S>(&self, other: &S) -> Result<f32, MeasureError>
    where
        Self: Sized,
        S: MeasurePos,
    {
        super::measure::rmsd(self, other)
    }
}
pub fn rmsd<S1, S2>(sel1: &S1, sel2: &S2) -> Result<f32, MeasureError>
where
S1: MeasurePos,
S2: MeasurePos,
{
let mut res = 0.0;
let iter1 = sel1.iter_pos();
let iter2 = sel2.iter_pos();
if sel1.len() != sel2.len() {
return Err(MeasureError::Sizes(sel1.len(), sel2.len()));
}
let n = sel1.len();
for (p1, p2) in std::iter::zip(iter1, iter2) {
res += (p2 - p1).norm_squared();
}
Ok((res / n as f32).sqrt())
}
pub trait MeasureMasses: PosIterProvider + MassIterProvider + LenProvider {
fn center_of_mass(&self) -> Result<Pos, MeasureError> {
let mut cm = Vector3f::zeros();
let mut mass = 0.0;
for (c, m) in zip(self.iter_pos(), self.iter_masses()) {
cm += c.coords * m;
mass += m;
}
if mass == 0.0 {
Err(MeasureError::ZeroMass)
} else {
Ok(Pos::from(cm / mass))
}
}
fn gyration(&self) -> Result<f32, MeasureError> {
let c = self.center_of_mass()?;
Ok(do_gyration(
self.iter_pos().map(|pos| pos - c),
self.iter_masses(),
))
}
fn inertia(&self) -> Result<(Vector3f, Matrix3f), MeasureError> {
let c = self.center_of_mass()?;
Ok(do_inertia(
self.iter_pos().map(|pos| pos - c),
self.iter_masses(),
))
}
fn principal_transform(&self) -> Result<IsometryMatrix3<f32>, MeasureError> {
let c = self.center_of_mass()?;
let (_, axes) = do_inertia(self.iter_pos().map(|pos| pos - c), self.iter_masses());
Ok(do_principal_transform(axes, c.coords))
}
fn fit_transform(
&self,
other: &impl MeasureMasses,
) -> Result<nalgebra::IsometryMatrix3<f32>, MeasureError>
where
Self: Sized,
{
super::measure::fit_transform(self, other)
}
fn fit_transform_at_origin(
&self,
other: &impl MeasureMasses,
) -> Result<nalgebra::IsometryMatrix3<f32>, MeasureError>
where
Self: Sized,
{
super::measure::fit_transform_at_origin(self, other)
}
fn rmsd_mw(&self, other: &impl MeasureMasses) -> Result<f32, MeasureError>
where
Self: Sized,
{
super::measure::rmsd_mw(self, other)
}
}
pub fn fit_transform(
sel1: &impl MeasureMasses,
sel2: &impl MeasureMasses,
) -> Result<nalgebra::IsometryMatrix3<f32>, MeasureError> {
let cm1 = sel1.center_of_mass()?;
let cm2 = sel2.center_of_mass()?;
let rot = rot_transform(
sel1.iter_pos().map(|p| *p - cm1),
sel2.iter_pos().map(|p| *p - cm2),
sel1.iter_masses(),
)?;
Ok(Translation3::from(cm2) * rot * Translation3::from(-cm1))
}
pub fn fit_transform_at_origin(
sel1: &impl MeasureMasses,
sel2: &impl MeasureMasses,
) -> Result<nalgebra::IsometryMatrix3<f32>, MeasureError> {
let rot = rot_transform(
sel1.iter_pos().map(|p| p.coords),
sel2.iter_pos().map(|p| p.coords),
sel1.iter_masses(),
)?;
Ok(nalgebra::convert(rot))
}
/// Mass-weighted RMSD between two selections, using the masses of `sel1`.
///
/// # Errors
/// * [`MeasureError::Sizes`] if the selections have different lengths.
/// * [`MeasureError::ZeroMass`] if the summed mass is zero.
pub fn rmsd_mw(sel1: &impl MeasureMasses, sel2: &impl MeasureMasses) -> Result<f32, MeasureError> {
    let (len1, len2) = (sel1.len(), sel2.len());
    if len1 != len2 {
        return Err(MeasureError::Sizes(len1, len2));
    }
    let (weighted_sq, total_mass) =
        izip!(sel1.iter_pos(), sel2.iter_pos(), sel1.iter_masses()).fold(
            (0.0f32, 0.0f32),
            |(acc, tot), (a, b, m)| (acc + (b - a).norm_squared() * m, tot + m),
        );
    if total_mass == 0.0 {
        return Err(MeasureError::ZeroMass);
    }
    Ok((weighted_sq / total_mass).sqrt())
}
pub trait MeasurePeriodic: PosIterProvider + MassIterProvider + BoxProvider + LenProvider {
fn center_of_mass_pbc(&self) -> Result<Pos, MeasureError> {
let b = self.require_box()?;
let mut pos_iter = self.iter_pos();
let mut mass_iter = self.iter_masses();
let mut mass = mass_iter.next().unwrap();
let p0 = pos_iter.next().unwrap();
let mut cm = p0.coords;
for (c, m) in zip(pos_iter, mass_iter) {
let im = b.closest_image(c, p0).coords;
cm += im * m;
mass += m;
}
if mass == 0.0 {
Err(MeasureError::ZeroMass)
} else {
Ok(Pos::from(cm / mass))
}
}
fn center_of_mass_pbc_dims(&self, dims: PbcDims) -> Result<Pos, MeasureError> {
let b = self.require_box()?;
let mut pos_iter = self.iter_pos();
let mut mass_iter = self.iter_masses();
let mut mass = mass_iter.next().unwrap();
let p0 = pos_iter.next().unwrap();
let mut cm = p0.coords;
for (c, m) in zip(pos_iter, mass_iter) {
let im = b.closest_image_dims(c, p0, dims).coords;
cm += im * m;
mass += m;
}
if mass == 0.0 {
Err(MeasureError::ZeroMass)
} else {
Ok(Pos::from(cm / mass))
}
}
fn center_of_geometry_pbc(&self) -> Result<Pos, MeasureError> {
let b = self.require_box()?;
let mut pos_iter = self.iter_pos();
let p0 = pos_iter.next().unwrap();
let mut cm = p0.coords;
for c in pos_iter {
cm += b.closest_image(c, p0).coords;
}
Ok(Pos::from(cm / self.len() as f32))
}
fn center_of_geometry_pbc_dims(&self, dims: PbcDims) -> Result<Pos, MeasureError> {
let b = self.require_box()?;
let mut pos_iter = self.iter_pos();
let p0 = pos_iter.next().unwrap();
let mut cm = p0.coords;
for c in pos_iter {
cm += b.closest_image_dims(c, p0, dims).coords;
}
Ok(Pos::from(cm / self.len() as f32))
}
fn gyration_pbc(&self) -> Result<f32, MeasureError> {
let b = self.require_box()?;
let c = self.center_of_mass_pbc()?;
Ok(do_gyration(
self.iter_pos().map(|pos| b.shortest_vector(&(pos - c))),
self.iter_masses(),
))
}
fn inertia_pbc(&self) -> Result<(Vector3f, Matrix3f), MeasureError> {
let b = self.require_box()?;
let c = self.center_of_mass_pbc()?;
Ok(do_inertia(
self.iter_pos().map(|pos| b.shortest_vector(&(pos - c))),
self.iter_masses(),
))
}
fn principal_transform_pbc(&self) -> Result<IsometryMatrix3<f32>, MeasureError> {
let b = self.require_box()?;
let c = self.center_of_mass_pbc()?;
let (_, axes) = do_inertia(
self.iter_pos().map(|pos| b.shortest_vector(&(pos - c))),
self.iter_masses(),
);
Ok(do_principal_transform(axes, c.coords))
}
}
/// Mass-weighted radius of gyration: `sqrt(Σ m·|d|² / Σ m)` over paired offsets and masses.
/// Returns NaN when the masses sum to zero.
fn do_gyration(dists: impl Iterator<Item = Vector3f>, masses: impl Iterator<Item = f32>) -> f32 {
    let (weighted_sq, total_mass) = zip(dists, masses).fold(
        (0.0f32, 0.0f32),
        |(sd, sm), (d, m)| (sd + d.norm_squared() * m, sm + m),
    );
    (weighted_sq / total_mass).sqrt()
}
/// Builds the inertia tensor of the given offsets/masses and diagonalizes it.
///
/// Returns the principal moments in ascending order, together with a matrix whose columns are
/// the matching axes. The third axis is the cross product of the first two, so the frame is
/// right-handed. Panics if an eigenvalue is NaN (the comparison is unwrapped).
fn do_inertia(
    dists: impl Iterator<Item = Vector3f>,
    masses: impl Iterator<Item = f32>,
) -> (Vector3f, Matrix3f) {
    // Accumulate the diagonal and upper triangle; the tensor is symmetric.
    let mut tensor = Matrix3f::zeros();
    for (r, m) in zip(dists, masses) {
        let (x, y, z) = (r.x, r.y, r.z);
        tensor[(0, 0)] += m * (y * y + z * z);
        tensor[(1, 1)] += m * (x * x + z * z);
        tensor[(2, 2)] += m * (x * x + y * y);
        tensor[(0, 1)] -= m * x * y;
        tensor[(0, 2)] -= m * x * z;
        tensor[(1, 2)] -= m * y * z;
    }
    // Mirror the upper triangle into the lower one.
    for (row, col) in [(1, 0), (2, 0), (2, 1)] {
        tensor[(row, col)] = tensor[(col, row)];
    }

    let eig = SymmetricEigen::new(tensor);
    let mut ranked: Vec<(usize, &f32)> = eig.eigenvalues.iter().enumerate().collect();
    ranked.sort_unstable_by(|lhs, rhs| lhs.1.partial_cmp(rhs.1).unwrap());

    let moments = Vector3f::new(*ranked[0].1, *ranked[1].1, *ranked[2].1);
    let first = eig.eigenvectors.column(ranked[0].0).normalize();
    let second = eig.eigenvectors.column(ranked[1].0).normalize();
    let third = first.cross(&second);
    (moments, Matrix3f::from_columns(&[first, second, third]))
}
/// Optimal (Kabsch) rotation mapping the `pos1` coordinates onto the paired `pos2` coordinates,
/// weighted by `masses`.
///
/// The SVD of the mass-weighted covariance `Σ m · p2 · p1ᵀ` gives `U`, `Vᵀ`. The last singular
/// direction is flipped when needed, so the result is a proper rotation rather than a
/// reflection.
///
/// # Errors
/// [`MeasureError::Svd`] if the SVD does not yield `U` or `Vᵀ`.
fn rot_transform(
    pos1: impl Iterator<Item = Vector3f>,
    pos2: impl Iterator<Item = Vector3f>,
    masses: impl Iterator<Item = f32>,
) -> Result<Rotation3<f32>, MeasureError> {
    let cov = izip!(pos1, pos2, masses).fold(Matrix3f::zeros(), |acc, (a, b, m)| {
        acc + b * a.transpose() * m
    });

    let svd = nalgebra::SVD::new(cov, true, true);
    let (u, v_t) = match (svd.u, svd.v_t) {
        (Some(u), Some(v_t)) => (u, v_t),
        _ => return Err(MeasureError::Svd),
    };

    let mut correction = Matrix3f::identity();
    if (u * v_t).determinant() < 0.0 {
        correction[(2, 2)] = -1.0;
    }
    Ok(Rotation3::from_matrix_unchecked(u * correction * v_t))
}
/// Transform that rotates about `cm` so that the principal axes in `axes` line up with the
/// coordinate axes.
///
/// `axes` holds the principal axes as columns, as produced by `do_inertia`: two normalized
/// symmetric-eigen vectors plus their cross product. It is therefore orthonormal and
/// right-handed, and its inverse is its transpose. Using the transpose instead of a general
/// inversion cannot fail (the old inversion's failure flag was ignored). It also keeps the
/// matrix exactly orthogonal, as `Rotation3::from_matrix_unchecked` requires.
fn do_principal_transform(axes: Matrix3f, cm: Vector3f) -> IsometryMatrix3<f32> {
    Translation3::from(cm)
        * Rotation3::from_matrix_unchecked(axes.transpose())
        * Translation3::from(-cm)
}
/// Measurements that need random access to positions by index.
pub trait MeasureRandomAccess: RandomPosProvider {
    /// Computes order parameters along a lipid tail.
    ///
    /// The positions of `self` are taken as consecutive tail carbons. One value is produced
    /// per interior carbon, so the result has `len - 2` entries, and entry `k` belongs to
    /// carbon `k + 1`.
    ///
    /// * `order_type`:
    ///   - [`OrderType::Sz`] uses `P2(cos θ)` of the angle between the vector joining a
    ///     carbon's two neighbours and the normal.
    ///   - [`OrderType::Scd`] and [`OrderType::ScdCorr`] build a local frame per carbon. They
    ///     differ only for carbons of double bonds.
    /// * `normals`: either one normal shared by all carbons, or one normal per result entry
    ///   (`len - 2` of them, indexed like the result).
    /// * `bond_orders`: `len - 1` values, where `bond_orders[k]` is the bond between carbons
    ///   `k` and `k + 1`. `1` means a single bond; any other value is treated as a double
    ///   bond. For `Sz` only its length is checked.
    ///
    /// # Errors
    /// * [`LipidOrderError::TailTooShort`] for fewer than 3 carbons.
    /// * [`LipidOrderError::NormalsCount`] or [`LipidOrderError::BondOrderCount`] when the
    ///   auxiliary inputs have the wrong length.
    fn lipid_tail_order(
        &self,
        order_type: OrderType,
        normals: &Vec<Vector3f>,
        bond_orders: &Vec<u8>,
    ) -> Result<DVector<f32>, LipidOrderError> {
        let n_carbons = self.len();
        if n_carbons < 3 {
            return Err(LipidOrderError::TailTooShort(n_carbons));
        }
        if normals.len() != 1 && normals.len() != n_carbons - 2 {
            return Err(LipidOrderError::NormalsCount(n_carbons, n_carbons - 2));
        }
        if bond_orders.len() != n_carbons - 1 {
            return Err(LipidOrderError::BondOrderCount(n_carbons, n_carbons - 1));
        }

        // Entry `k` holds carbon `k + 1`; the two terminal carbons get no value.
        let mut order = DVector::from_element(n_carbons - 2, 0.0);
        let corrected = order_type == OrderType::ScdCorr;

        match order_type {
            OrderType::Sz => {
                for at in 1..n_carbons - 1 {
                    // SAFETY: `1 <= at <= n_carbons - 2`, so `at - 1` and `at + 1` are in bounds.
                    let v = unsafe {
                        self.get_pos_unchecked(at + 1) - self.get_pos_unchecked(at - 1)
                    };
                    let ang = v.angle(lipid_normal_at(normals, at - 1));
                    order[at - 1] = 1.5 * ang.cos().powi(2) - 0.5;
                }
            }
            OrderType::Scd | OrderType::ScdCorr => {
                // Visit every bond: `bond_orders[i]` joins carbons `i` and `i + 1`.
                for i in 0..n_carbons - 1 {
                    if bond_orders[i] == 1 {
                        // Carbon `i + 1` between two single bonds -> entry `i`. If the next bond
                        // is double, that carbon is handled when the double bond is visited.
                        if i + 1 < bond_orders.len() && bond_orders[i + 1] == 1 {
                            // SAFETY: `i + 1 < n_carbons - 1`, hence `i + 2 < n_carbons`.
                            let (p1, p2, p3) = unsafe {
                                (
                                    self.get_pos_unchecked(i),
                                    self.get_pos_unchecked(i + 1),
                                    self.get_pos_unchecked(i + 2),
                                )
                            };
                            let local_z = (p3 - p1).normalize();
                            let local_x = (p1 - p2).cross(&(p3 - p2)).normalize();
                            let local_y = local_x.cross(&local_z);
                            let nrm = lipid_normal_at(normals, i);
                            let ang_x = local_x.angle(nrm);
                            let ang_y = local_y.angle(nrm);
                            let sxx = 0.5 * (3.0 * ang_x.cos().powi(2) - 1.0);
                            let syy = 0.5 * (3.0 * ang_y.cos().powi(2) - 1.0);
                            order[i] = -(2.0 * sxx + syy) / 3.0;
                        }
                    } else {
                        // Double bond between carbons `i` and `i + 1`.
                        // SAFETY: `i <= n_carbons - 2`, so `i` and `i + 1` are in bounds.
                        let (p2, p3) =
                            unsafe { (self.get_pos_unchecked(i), self.get_pos_unchecked(i + 1)) };
                        let local_z = (p3 - p2).normalize();

                        // First carbon of the bond (carbon `i`, entry `i - 1`). It needs the
                        // preceding carbon, which does not exist when the tail starts with the
                        // double bond; carbon 0 has no entry anyway.
                        if i > 0 {
                            // SAFETY: `i >= 1`, so `i - 1` does not underflow and is in bounds.
                            let p1 = unsafe { self.get_pos_unchecked(i - 1) };
                            let a1 = 0.5 * (PI - (p1 - p2).angle(&(p3 - p2)));
                            let local_x = (p1 - p2).cross(&local_z).normalize();
                            let local_y = local_x.cross(&local_z);
                            let nrm = lipid_normal_at(normals, i - 1);
                            let ang_y = local_y.angle(nrm);
                            let ang_z = local_z.angle(nrm);
                            let szz = 0.5 * (3.0 * ang_z.cos().powi(2) - 1.0);
                            let syy = 0.5 * (3.0 * ang_y.cos().powi(2) - 1.0);
                            let syz = 1.5 * ang_y.cos() * ang_z.cos();
                            order[i - 1] = if corrected {
                                -(a1.cos().powi(2) * syy + a1.sin().powi(2) * szz
                                    - 2.0 * a1.cos() * a1.sin() * syz)
                            } else {
                                -(szz / 4.0 + 3.0 * syy / 4.0 - 3.0_f32.sqrt() * syz / 2.0)
                            };
                        }

                        // Second carbon of the bond (carbon `i + 1`, entry `i`). It needs the
                        // following carbon, absent when the double bond ends the tail.
                        if i + 2 < n_carbons {
                            // SAFETY: checked `i + 2 < n_carbons` just above.
                            let p4 = unsafe { self.get_pos_unchecked(i + 2) };
                            let a2 = 0.5 * (PI - (p2 - p3).angle(&(p4 - p3)));
                            let local_x = (p3 - p4).cross(&local_z).normalize();
                            let local_y = local_x.cross(&local_z);
                            let nrm = lipid_normal_at(normals, i);
                            let ang_y = local_y.angle(nrm);
                            let ang_z = local_z.angle(nrm);
                            let szz = 0.5 * (3.0 * ang_z.cos().powi(2) - 1.0);
                            let syy = 0.5 * (3.0 * ang_y.cos().powi(2) - 1.0);
                            let syz = 1.5 * ang_y.cos() * ang_z.cos();
                            order[i] = if corrected {
                                -(a2.cos().powi(2) * syy
                                    + a2.sin().powi(2) * szz
                                    + 2.0 * a2.cos() * a2.sin() * syz)
                            } else {
                                -(szz / 4.0 + 3.0 * syy / 4.0 + 3.0_f32.sqrt() * syz / 2.0)
                            };
                        }
                    }
                }
            }
        }
        Ok(order)
    }
}

/// Normal used for result entry `k` of `lipid_tail_order`: the shared normal when only one
/// is given, otherwise `normals[k]`.
fn lipid_normal_at(normals: &[Vector3f], k: usize) -> &Vector3f {
    if normals.len() == 1 {
        &normals[0]
    } else {
        &normals[k]
    }
}
/// Pairs atoms of `seq1` and `seq2` by globally aligning their atom-name sequences.
///
/// Identical names score +1 and differing names -1; the alignment is called with gap
/// penalties -10 and -1 (presumably gap-open / gap-extend). Returns two equally long index
/// vectors: position `k` of each refers to the `k`-th matched atom in `seq1` and `seq2`
/// respectively. Substituted (name-mismatched) pairs and gaps are skipped.
pub fn get_matching_atoms_by_name<T1, T2>(seq1: &T1, seq2: &T2) -> (Vec<usize>, Vec<usize>)
where
    T1: AtomIterProvider,
    T2: AtomIterProvider,
{
    use crate::seq_align::*;

    let names1: Vec<&str> = seq1.iter_atoms().map(|a| a.name.as_str()).collect();
    let names2: Vec<&str> = seq2.iter_atoms().map(|a| a.name.as_str()).collect();
    let score = |a: &&str, b: &&str| if a == b { 1 } else { -1 };
    let alignment = global_align_affine(&names1, &names2, -10, -1, score);

    let mut matched1 = Vec::new();
    let mut matched2 = Vec::new();
    let (mut i1, mut i2) = (0, 0);
    for op in alignment.operations.iter() {
        // How far each sequence advances for this operation.
        let (step1, step2) = match op {
            AlignmentOperation::Match => {
                matched1.push(i1);
                matched2.push(i2);
                (1, 1)
            }
            AlignmentOperation::Subst => (1, 1),
            AlignmentOperation::Del => (0, 1),
            AlignmentOperation::Ins => (1, 0),
        };
        i1 += step1;
        i2 += step2;
    }
    (matched1, matched2)
}
pub fn fit_transform_matching(
sel1: &(impl IndexSliceProvider + SelectableBound),
sel2: &(impl IndexSliceProvider + SelectableBound),
) -> Result<nalgebra::IsometryMatrix3<f32>, MeasureError> {
let (ind1, ind2) = get_matching_atoms_by_name(sel1, sel2);
let matched_sel1 = sel1.select_bound(ind1).unwrap();
let matched_sel2 = sel2.select_bound(ind2).unwrap();
fit_transform(&matched_sel1, &matched_sel2)
}
/// Kind of order parameter computed by `MeasureRandomAccess::lipid_tail_order`.
#[derive(PartialEq, Debug, Clone, Deserialize)]
pub enum OrderType {
/// `P2(cos θ) = 1.5·cos²θ - 0.5`, where θ is the angle between the vector joining a carbon's
/// two neighbours and the normal.
Sz,
/// Order parameter from a local frame built around each carbon. Carbons of double bonds use
/// fixed coefficients (1/4, 3/4, √3/2).
Scd,
/// Like `Scd`, but carbons of double bonds use coefficients derived from the actual bond
/// angle at that carbon.
ScdCorr,
}
/// Input-validation errors of `MeasureRandomAccess::lipid_tail_order`.
#[derive(Error, Debug)]
pub enum LipidOrderError {
/// Number of normals is neither 1 nor the expected count. Payload: `(carbons, expected)`.
#[error("for {0} tail carbons # of normals should be 1 or {1}")]
NormalsCount(usize, usize),
/// Number of bond orders differs from the expected count. Payload: `(carbons, expected)`.
#[error("for {0} tail carbons # of bond orders should be {1}")]
BondOrderCount(usize, usize),
/// Fewer than 3 carbons were supplied. Payload: the actual count.
#[error("tail should have at least 3 carbons, not {0}")]
TailTooShort(usize),
}
/// Measurements that need both atom records and positions.
pub trait MeasureAtomPos: AtomIterProvider + PosIterProvider + LenProvider {
    /// Surface-area measurement of the selection, delegated to `Sasa::new`.
    fn sasa(&self) -> Result<Sasa, MeasureError> {
        let result = Sasa::new(self)?;
        Ok(result)
    }

    /// Surface area together with volume, delegated to `Sasa::new_with_volume`.
    fn sasa_vol(&self) -> Result<Sasa, MeasureError> {
        let result = Sasa::new_with_volume(self)?;
        Ok(result)
    }
}
/// Errors returned by the measurement routines of this module.
#[derive(Error, Debug)]
pub enum MeasureError {
/// Two selections that are paired atom-by-atom have different lengths.
#[error("incompatible sizes: {0} and {1}")]
Sizes(usize, usize),
/// Total mass is zero, so a mass-weighted average is undefined.
#[error("zero mass")]
ZeroMass,
/// The SVD used for fitting did not produce `U` or `Vᵀ`.
#[error("SVD failed")]
Svd,
/// Periodic-box error, forwarded transparently (presumably what `require_box` fails with —
/// confirm).
#[error(transparent)]
Pbc(#[from] PeriodicBoxError),
/// Disjoint pieces could not be unwrapped (not raised in the code visible here).
#[error("can't unwrap disjoint pieces")]
Disjoint,
/// Wraps a [`LipidOrderError`].
#[error("lipid order error")]
LipidOrder(#[from] LipidOrderError),
/// Selection-related failure.
#[error("selection error")]
Sel,
/// Wraps an error from the `powersasa` surface-area computation.
#[error("sasa error")]
Sasa(#[from] SasaError),
}
#[cfg(test)]
mod tests {
    use crate::prelude::*;

    /// Removing a few atoms from the second selection must not break name-based matching:
    /// every matched pair has to refer to atoms with identical names.
    #[test]
    fn test_get_matching_atoms_by_name() -> anyhow::Result<()> {
        let sys = System::from_file("tests/albumin.pdb")?;
        let full = sys.select_bound("resindex 1:5")?;
        let pruned = sys.select_bound(
            "resindex 1:5 and not ((resindex 1 2 and name CA) or (resindex 3 and name N))",
        )?;

        let (idx_full, idx_pruned) = get_matching_atoms_by_name(&full, &pruned);
        let names1: Vec<_> = idx_full
            .iter()
            .map(|&i| full.get_atom(i).unwrap().name.clone())
            .collect();
        let names2: Vec<_> = idx_pruned
            .iter()
            .map(|&i| pruned.get_atom(i).unwrap().name.clone())
            .collect();

        println!("{names1:?}");
        println!("{names2:?}");
        assert_eq!(names1, names2);
        Ok(())
    }
}