use crate::db;
use crate::model::{
atom::Atom,
residue::Residue,
structure::Structure,
types::{Element, Point, ResidueCategory, ResiduePosition},
};
use crate::ops::error::Error;
use crate::utils::parallel::*;
use nalgebra::{Matrix3, Rotation3, Vector3};
use std::collections::HashSet;
/// Distance from the carboxyl carbon at which a synthesized `OXT` is placed
/// (reported in Å by this file's geometry test).
const CARBOXYL_CO_BOND_LENGTH: f64 = 1.25;
/// Target O–C–OXT angle, in degrees; converted to radians where `OXT` is placed.
const CARBOXYL_OCO_ANGLE_DEG: f64 = 126.0;
/// Distance from `P` at which a synthesized `OP3` is placed
/// (reported in Å by this file's geometry test).
const PHOSPHATE_PO_BOND_LENGTH: f64 = 1.48;
/// `(position in the residue, position in the template)` for every template
/// heavy atom that is also present in the residue.
type AlignmentPairs = Vec<(Point, Point)>;
/// `(atom name, element, template position)` for every template heavy atom
/// that is absent from the residue.
type MissingAtoms = Vec<(String, Element, Point)>;
/// Repairs every residue of `structure` whose category is
/// [`ResidueCategory::Standard`]; residues of any other category are skipped.
///
/// # Errors
///
/// Returns the error produced by `repair_residue` if repairing a residue fails.
pub fn repair_structure(structure: &mut Structure) -> Result<(), Error> {
    structure
        .par_residues_mut()
        .filter(|residue| residue.category == ResidueCategory::Standard)
        .try_for_each(repair_residue)
}
fn repair_residue(residue: &mut Residue) -> Result<(), Error> {
let template_name = residue.name.clone();
let template =
db::get_template(&template_name).ok_or_else(|| Error::MissingInternalTemplate {
res_name: template_name.to_string(),
})?;
let status = detect_terminal_status(residue);
let valid_names = build_valid_names(template, &status);
clean_invalid_atoms(residue, &valid_names);
let (align_pairs, missing_atoms) = collect_alignment_data(residue, template, &status);
if align_pairs.is_empty() {
return Err(Error::alignment_failed(
&*residue.name,
residue.id,
"No matching heavy atoms found for alignment",
));
}
let transform = calculate_transform(&align_pairs)?;
synthesize_missing_template_atoms(residue, missing_atoms, &transform);
synthesize_terminal_atoms(residue, &status);
Ok(())
}
/// Terminal-related flags for a residue, as computed by `detect_terminal_status`.
struct TerminalStatus {
/// The residue is a protein residue positioned at the C-terminus.
is_protein_c_term: bool,
/// The residue is a nucleic-acid residue positioned at the 5' end.
is_nucleic_5prime: bool,
/// The residue is a 5' nucleic-acid residue that had a `P` atom when the
/// status was computed.
has_5prime_phosphate: bool,
}
/// Derives the terminal flags of `residue` from its standard name, its
/// position in the chain and whether it currently carries a `P` atom.
///
/// A residue without a standard name is neither protein nor nucleic, so all
/// flags come out `false` for it.
fn detect_terminal_status(residue: &Residue) -> TerminalStatus {
    let (is_protein, is_nucleic) = match residue.standard_name {
        Some(std_name) => (std_name.is_protein(), std_name.is_nucleic()),
        None => (false, false),
    };
    let nucleic_at_5prime = is_nucleic && residue.position == ResiduePosition::FivePrime;

    TerminalStatus {
        is_protein_c_term: is_protein && residue.position == ResiduePosition::CTerminal,
        is_nucleic_5prime: nucleic_at_5prime,
        // The atom lookup only happens for 5' nucleic residues.
        has_5prime_phosphate: nucleic_at_5prime && residue.has_atom("P"),
    }
}
/// Builds the set of atom names a residue may keep.
///
/// The set starts from every heavy-atom and hydrogen name of `template` and is
/// then adjusted for terminal residues:
/// * protein C-terminus: `OXT` is additionally allowed;
/// * nucleic 5' end with a phosphate: `OP3` is additionally allowed;
/// * nucleic 5' end without a phosphate: `P`, `OP1` and `OP2` are disallowed.
fn build_valid_names(template: db::TemplateView, status: &TerminalStatus) -> HashSet<String> {
    let heavy_names = template
        .heavy_atoms()
        .map(|(atom_name, _, _)| atom_name.to_string());
    let hydrogen_names = template
        .hydrogens()
        .map(|(atom_name, _, _)| atom_name.to_string());
    let mut names: HashSet<String> = heavy_names.chain(hydrogen_names).collect();

    if status.is_protein_c_term {
        names.insert("OXT".to_string());
    }

    match (status.is_nucleic_5prime, status.has_5prime_phosphate) {
        (true, true) => {
            names.insert("OP3".to_string());
        }
        (true, false) => {
            for phosphate_atom in ["P", "OP1", "OP2"] {
                names.remove(phosphate_atom);
            }
        }
        (false, _) => {}
    }

    names
}
/// Removes from `residue` every atom whose name is not in `valid_names`.
///
/// Names are gathered first and removed afterwards, so the atom list is not
/// mutated while it is being read.
fn clean_invalid_atoms(residue: &mut Residue, valid_names: &HashSet<String>) {
    let mut rejected: Vec<String> = Vec::new();
    for atom in residue.atoms().iter() {
        if !valid_names.contains(atom.name.as_str()) {
            rejected.push(atom.name.to_string());
        }
    }

    for rejected_name in &rejected {
        residue.remove_atom(rejected_name);
    }
}
/// Splits the template's heavy atoms into those the residue already has and
/// those it lacks.
///
/// Returns `(align_pairs, missing_atoms)`:
/// * `align_pairs` holds `(residue position, template position)` for each
///   heavy atom found in the residue;
/// * `missing_atoms` holds `(name, element, template position)` for each heavy
///   atom not found in the residue.
///
/// For a 5' nucleic residue without a phosphate, `P`, `OP1` and `OP2` are left
/// out of both lists.
fn collect_alignment_data(
    residue: &Residue,
    template: db::TemplateView,
    status: &TerminalStatus,
) -> (AlignmentPairs, MissingAtoms) {
    let skip_phosphate_group = status.is_nucleic_5prime && !status.has_5prime_phosphate;

    let mut present: AlignmentPairs = Vec::new();
    let mut absent: MissingAtoms = Vec::new();

    for (atom_name, element, template_pos) in template.heavy_atoms() {
        if skip_phosphate_group && matches!(atom_name, "P" | "OP1" | "OP2") {
            continue;
        }
        match residue.atom(atom_name) {
            Some(existing) => present.push((existing.pos, template_pos)),
            None => absent.push((atom_name.to_string(), element, template_pos)),
        }
    }

    (present, absent)
}
/// A rotation followed by a translation, applied to points via `Transform::apply`.
struct Transform {
/// Rotation matrix applied to a point's coordinates first.
rotation: Matrix3<f64>,
/// Offset added after the rotation.
translation: Vector3<f64>,
}
impl Transform {
    /// Maps `point` through the transform: rotate its coordinates, then add
    /// the translation.
    fn apply(&self, point: Point) -> Point {
        let rotated = self.rotation * point.coords;
        Point::from(rotated + self.translation)
    }
}
fn calculate_transform(pairs: &[(Point, Point)]) -> Result<Transform, Error> {
let n = pairs.len();
let center_res = pairs.iter().map(|p| p.0.coords).sum::<Vector3<f64>>() / n as f64;
let center_tmpl = pairs.iter().map(|p| p.1.coords).sum::<Vector3<f64>>() / n as f64;
if n == 1 {
return Ok(Transform {
rotation: Matrix3::identity(),
translation: center_res - center_tmpl,
});
}
if n == 2 {
let v_res = pairs[1].0 - pairs[0].0;
let v_tmpl = pairs[1].1 - pairs[0].1;
let rotation =
Rotation3::rotation_between(&v_tmpl, &v_res).unwrap_or_else(Rotation3::identity);
return Ok(Transform {
rotation: rotation.into_inner(),
translation: center_res - rotation * center_tmpl,
});
}
let mut cov = Matrix3::zeros();
for (p_res, p_tmpl) in pairs {
let v_res = p_res.coords - center_res;
let v_tmpl = p_tmpl.coords - center_tmpl;
cov += v_res * v_tmpl.transpose();
}
let svd = cov.svd(true, true);
let u = svd.u.ok_or_else(|| {
Error::alignment_failed("", 0, "SVD decomposition failed: U matrix unavailable")
})?;
let v_t = svd.v_t.ok_or_else(|| {
Error::alignment_failed("", 0, "SVD decomposition failed: V^T matrix unavailable")
})?;
let mut rotation = u * v_t;
if rotation.determinant() < 0.0 {
let mut correction = Matrix3::identity();
correction[(2, 2)] = -1.0;
rotation = u * correction * v_t;
}
Ok(Transform {
rotation,
translation: center_res - rotation * center_tmpl,
})
}
/// Adds each missing template atom to `residue`, positioned by mapping its
/// template position through `transform`.
fn synthesize_missing_template_atoms(
    residue: &mut Residue,
    missing_atoms: Vec<(String, Element, Point)>,
    transform: &Transform,
) {
    missing_atoms
        .into_iter()
        .for_each(|(atom_name, element, template_pos)| {
            let placed_at = transform.apply(template_pos);
            residue.add_atom(Atom::new(&atom_name, element, placed_at));
        });
}
fn synthesize_terminal_atoms(residue: &mut Residue, status: &TerminalStatus) {
if status.is_protein_c_term {
synthesize_oxt(residue);
}
if status.has_5prime_phosphate {
synthesize_op3(residue);
}
}
/// Adds an `OXT` atom to `residue` if it does not already have one.
///
/// `OXT` is placed in the plane spanned by the CA→C direction and the C→O
/// bond, `CARBOXYL_CO_BOND_LENGTH` away from `C`, such that the O–C–OXT angle
/// equals `CARBOXYL_OCO_ANGLE_DEG`.
///
/// Nothing is added when `C`, `O` or `CA` is missing, or when the geometry is
/// degenerate (`C` coincides with `CA` or `O`, or `O` lies on the CA–C line),
/// because no placement plane is defined then.
fn synthesize_oxt(residue: &mut Residue) {
    // Vectors at or below this length are treated as having no direction.
    const MIN_NORM: f64 = 1e-10;

    if residue.has_atom("OXT") {
        return;
    }
    let (c, o, ca) = match (residue.atom("C"), residue.atom("O"), residue.atom("CA")) {
        (Some(c), Some(o), Some(ca)) => (c.pos, o.pos, ca.pos),
        _ => return,
    };

    // Plain `normalize()` would turn a zero-length bond into NaN, which slips
    // past the projection check below and yields an atom at NaN coordinates.
    let (Some(v_c_ca), Some(v_c_o)) = (
        (ca - c).try_normalize(MIN_NORM),
        (o - c).try_normalize(MIN_NORM),
    ) else {
        return;
    };

    // Local frame: `z` points from CA through C; `x` is the part of the C→O
    // direction perpendicular to `z`.
    let z = -v_c_ca;
    let o_proj = v_c_o - z * v_c_o.dot(&z);
    let Some(x) = o_proj.try_normalize(MIN_NORM) else {
        return;
    };

    // Angle of O from `z`. OXT sits on the other side of `z` in the same
    // plane, so the two polar angles must add up to the target angle. The
    // difference is kept signed: if O is already further than the target from
    // `z`, OXT moves to O's side and the O–C–OXT angle still equals the target.
    let theta_o = v_c_o.dot(&z).clamp(-1.0, 1.0).acos();
    let theta_oxt = CARBOXYL_OCO_ANGLE_DEG.to_radians() - theta_o;

    let oxt_direction = z * theta_oxt.cos() - x * theta_oxt.sin();
    let oxt_pos = c + oxt_direction * CARBOXYL_CO_BOND_LENGTH;
    residue.add_atom(Atom::new("OXT", Element::O, oxt_pos));
}
/// Adds an `OP3` atom to `residue` if it does not already have one.
///
/// `OP3` is placed `PHOSPHATE_PO_BOND_LENGTH` away from `P`, in the direction
/// pointing from the centroid of `OP1`, `OP2` and `O5'` through `P`.
///
/// Nothing is added when any of `P`, `OP1`, `OP2`, `O5'` is missing, or when
/// `P` coincides with that centroid (no direction is defined then).
fn synthesize_op3(residue: &mut Residue) {
    if residue.has_atom("OP3") {
        return;
    }
    let (p, op1, op2, o5) = match (
        residue.atom("P"),
        residue.atom("OP1"),
        residue.atom("OP2"),
        residue.atom("O5'"),
    ) {
        (Some(p), Some(op1), Some(op2), Some(o5)) => (p.pos, op1.pos, op2.pos, o5.pos),
        _ => return,
    };

    let centroid = Point::from((op1.coords + op2.coords + o5.coords) / 3.0);
    // Plain `normalize()` on a zero vector gives NaN and would add an atom at
    // NaN coordinates; skip the placement instead.
    let Some(direction) = (p - centroid).try_normalize(1e-10) else {
        return;
    };

    let op3_pos = p + direction * PHOSPHATE_PO_BOND_LENGTH;
    residue.add_atom(Atom::new("OP3", Element::O, op3_pos));
}
// Unit tests for residue/structure repair. Fixtures are built by copying
// selected atoms out of the internal templates returned by `db::get_template`.
#[cfg(test)]
mod tests {
use super::*;
use crate::model::{
atom::Atom,
chain::Chain,
residue::Residue,
types::{Element, Point, ResidueCategory, ResiduePosition, StandardResidue},
};
/// Copies the heavy atom called `atom_name` (element and position) from
/// `template` into `residue`. Panics if the template has no such heavy atom.
fn add_atom_from_template(
residue: &mut Residue,
template: db::TemplateView<'_>,
atom_name: &str,
) {
let (_, element, pos) = template
.heavy_atoms()
.find(|(name, _, _)| *name == atom_name)
.unwrap_or_else(|| panic!("template atom {atom_name} missing"));
residue.add_atom(Atom::new(atom_name, element, pos));
}
/// Copies the hydrogen called `atom_name` from `template` into `residue`,
/// using the second tuple element as its position and `Element::H` as its
/// element. Panics if the template has no such hydrogen.
fn add_hydrogen_from_template(
residue: &mut Residue,
template: db::TemplateView<'_>,
atom_name: &str,
) {
let (_, pos, _) = template
.hydrogens()
.find(|(name, _, _)| *name == atom_name)
.unwrap_or_else(|| panic!("template hydrogen {atom_name} missing"));
residue.add_atom(Atom::new(atom_name, Element::H, pos));
}
/// Builds an empty residue of category `Standard` with the given name, id
/// and standard-residue tag (the second constructor argument is `None`).
fn standard_residue(name: &str, id: i32, std: StandardResidue) -> Residue {
Residue::new(id, None, name, Some(std), ResidueCategory::Standard)
}
/// Euclidean distance between two points.
fn distance(a: Point, b: Point) -> f64 {
(a - b).norm()
}
/// Angle a–center–b in degrees.
fn angle_deg(a: Point, center: Point, b: Point) -> f64 {
let v1 = (a - center).normalize();
let v2 = (b - center).normalize();
// Clamp guards `acos` against rounding slightly outside [-1, 1].
v1.dot(&v2).clamp(-1.0, 1.0).acos().to_degrees()
}
// An internal ALA seeded with N, CA, the hydrogen HA and a bogus atom: after
// repair every template heavy atom exists, HA is kept and the bogus atom is gone.
#[test]
fn repair_residue_rebuilds_missing_heavy_atoms_and_cleans_extras() {
let template = db::get_template("ALA").expect("template ALA");
let mut residue = standard_residue("ALA", 1, StandardResidue::ALA);
residue.position = ResiduePosition::Internal;
add_atom_from_template(&mut residue, template, "N");
add_atom_from_template(&mut residue, template, "CA");
add_hydrogen_from_template(&mut residue, template, "HA");
residue.add_atom(Atom::new("FAKE", Element::C, Point::new(5.0, 5.0, 5.0)));
repair_residue(&mut residue).expect("repair succeeds");
for (name, _, _) in template.heavy_atoms() {
assert!(residue.has_atom(name), "missing heavy atom {name}");
}
assert!(
residue.has_atom("HA"),
"valid hydrogen removed unexpectedly"
);
assert!(
!residue.has_atom("FAKE"),
"extraneous atom should be removed"
);
}
// A C-terminal ALA seeded with C, CA and O gains an oxygen named OXT.
#[test]
fn repair_residue_adds_oxt_for_cterm_protein() {
let template = db::get_template("ALA").expect("template ALA");
let mut residue = standard_residue("ALA", 10, StandardResidue::ALA);
residue.position = ResiduePosition::CTerminal;
add_atom_from_template(&mut residue, template, "C");
add_atom_from_template(&mut residue, template, "CA");
add_atom_from_template(&mut residue, template, "O");
repair_residue(&mut residue).expect("repair succeeds");
let oxt = residue.atom("OXT").expect("OXT should be synthesized");
assert_eq!(oxt.element, Element::O);
}
// A 5' DA that already has its phosphate group keeps P/OP1/OP2/O5' and
// gains an oxygen named OP3.
#[test]
fn repair_residue_adds_op3_for_5prime_with_phosphate() {
let template = db::get_template("DA").expect("template DA");
let mut residue = standard_residue("DA", 1, StandardResidue::DA);
residue.position = ResiduePosition::FivePrime;
add_atom_from_template(&mut residue, template, "P");
add_atom_from_template(&mut residue, template, "OP1");
add_atom_from_template(&mut residue, template, "OP2");
add_atom_from_template(&mut residue, template, "O5'");
add_atom_from_template(&mut residue, template, "C5'");
add_atom_from_template(&mut residue, template, "C4'");
repair_residue(&mut residue).expect("repair succeeds");
assert!(residue.has_atom("P"), "phosphorus should be retained");
assert!(residue.has_atom("OP1"), "OP1 should be retained");
assert!(residue.has_atom("OP2"), "OP2 should be retained");
assert!(residue.has_atom("O5'"), "O5' should be retained");
let op3 = residue
.atom("OP3")
.expect("OP3 should be synthesized for 5'-phosphate");
assert_eq!(op3.element, Element::O);
}
// A 5' DA without a P atom must not have any phosphate atoms built for it.
#[test]
fn repair_residue_excludes_phosphate_for_5prime_without_p() {
let template = db::get_template("DA").expect("template DA");
let mut residue = standard_residue("DA", 1, StandardResidue::DA);
residue.position = ResiduePosition::FivePrime;
add_atom_from_template(&mut residue, template, "O5'");
add_atom_from_template(&mut residue, template, "C5'");
add_atom_from_template(&mut residue, template, "C4'");
add_atom_from_template(&mut residue, template, "C3'");
repair_residue(&mut residue).expect("repair succeeds");
assert!(!residue.has_atom("P"), "P should not be synthesized");
assert!(!residue.has_atom("OP1"), "OP1 should not be synthesized");
assert!(!residue.has_atom("OP2"), "OP2 should not be synthesized");
assert!(!residue.has_atom("OP3"), "OP3 should not be synthesized");
assert!(residue.has_atom("O5'"), "O5' should be retained");
}
// A 3' DA still has O3' after repair.
#[test]
fn repair_residue_3prime_nucleic_preserves_o3() {
let template = db::get_template("DA").expect("template DA");
let mut residue = standard_residue("DA", 10, StandardResidue::DA);
residue.position = ResiduePosition::ThreePrime;
add_atom_from_template(&mut residue, template, "C3'");
add_atom_from_template(&mut residue, template, "O3'");
add_atom_from_template(&mut residue, template, "C4'");
add_atom_from_template(&mut residue, template, "C5'");
repair_residue(&mut residue).expect("repair succeeds");
assert!(
residue.has_atom("O3'"),
"O3' should be present for 3' terminal"
);
}
// A residue whose only atom is not a template atom has nothing to align
// on, so repair must fail with `Error::AlignmentFailed`.
#[test]
fn repair_residue_errors_when_no_alignment_atoms_survive() {
let mut residue = standard_residue("ALA", 2, StandardResidue::ALA);
residue.add_atom(Atom::new("FAKE", Element::C, Point::origin()));
let err = repair_residue(&mut residue).expect_err("should fail without anchors");
match err {
Error::AlignmentFailed { .. } => {}
other => panic!("unexpected error: {other:?}"),
}
}
// `repair_structure` completes the standard GLY but leaves the hetero
// residue (and its non-template atom) alone.
#[test]
fn repair_structure_updates_standard_residues_only() {
let template = db::get_template("GLY").expect("template GLY");
let mut standard = standard_residue("GLY", 5, StandardResidue::GLY);
add_atom_from_template(&mut standard, template, "N");
add_atom_from_template(&mut standard, template, "CA");
let mut hetero = Residue::new(20, None, "LIG", None, ResidueCategory::Hetero);
hetero.add_atom(Atom::new("XX", Element::C, Point::new(-1.0, 0.0, 0.0)));
let mut chain = Chain::new("A");
chain.add_residue(standard);
chain.add_residue(hetero);
let mut structure = Structure::new();
structure.add_chain(chain);
repair_structure(&mut structure).expect("repair succeeds");
let chain = structure.chain("A").expect("chain A");
let fixed = chain.residue(5, None).unwrap();
for (name, _, _) in template.heavy_atoms() {
assert!(fixed.has_atom(name), "missing atom {name} after repair");
}
let hetero_after = chain.residue(20, None).unwrap();
assert!(
hetero_after.has_atom("XX"),
"hetero residue should remain untouched"
);
}
// Geometry of the synthesized OXT on a C-terminal ALA carrying all template
// heavy atoms: C–OXT length, O–C–OXT angle and clearance from CA.
#[test]
fn oxt_geometry_has_correct_distance_and_angle() {
let template = db::get_template("ALA").expect("template ALA");
let mut residue = standard_residue("ALA", 1, StandardResidue::ALA);
residue.position = ResiduePosition::CTerminal;
for (name, element, pos) in template.heavy_atoms() {
residue.add_atom(Atom::new(name, element, pos));
}
repair_residue(&mut residue).expect("repair succeeds");
let c = residue.atom("C").expect("C").pos;
let ca = residue.atom("CA").expect("CA").pos;
let o = residue.atom("O").expect("O").pos;
let oxt = residue.atom("OXT").expect("OXT").pos;
let c_oxt_dist = distance(c, oxt);
assert!(
(c_oxt_dist - CARBOXYL_CO_BOND_LENGTH).abs() < 0.1,
"C-OXT distance {c_oxt_dist:.3} should be ~{CARBOXYL_CO_BOND_LENGTH} Å"
);
let o_c_oxt_angle = angle_deg(o, c, oxt);
assert!(
(o_c_oxt_angle - CARBOXYL_OCO_ANGLE_DEG).abs() < 3.0,
"O-C-OXT angle {o_c_oxt_angle:.1}° should be ~{CARBOXYL_OCO_ANGLE_DEG}°"
);
let oxt_ca_dist = distance(oxt, ca);
assert!(
oxt_ca_dist > 2.0,
"OXT-CA distance {oxt_ca_dist:.3} should be > 2.0 Å (OXT must be away from CA)"
);
}
// Geometry of the synthesized OP3 on a 5' DA carrying all template heavy
// atoms: P–OP3 length and the three O–P–OP3 angles, each within 10° of 109.5°.
#[test]
fn op3_geometry_has_correct_distance_and_tetrahedral_angles() {
let template = db::get_template("DA").expect("template DA");
let mut residue = standard_residue("DA", 1, StandardResidue::DA);
residue.position = ResiduePosition::FivePrime;
for (name, element, pos) in template.heavy_atoms() {
residue.add_atom(Atom::new(name, element, pos));
}
repair_residue(&mut residue).expect("repair succeeds");
let p = residue.atom("P").expect("P").pos;
let op1 = residue.atom("OP1").expect("OP1").pos;
let op2 = residue.atom("OP2").expect("OP2").pos;
let o5 = residue.atom("O5'").expect("O5'").pos;
let op3 = residue.atom("OP3").expect("OP3").pos;
let p_op3_dist = distance(p, op3);
assert!(
(p_op3_dist - PHOSPHATE_PO_BOND_LENGTH).abs() < 0.1,
"P-OP3 distance {p_op3_dist:.3} should be ~{PHOSPHATE_PO_BOND_LENGTH} Å"
);
let tetrahedral_angle = 109.5;
let tolerance = 10.0;
let op1_p_op3 = angle_deg(op1, p, op3);
let op2_p_op3 = angle_deg(op2, p, op3);
let o5_p_op3 = angle_deg(o5, p, op3);
assert!(
(op1_p_op3 - tetrahedral_angle).abs() < tolerance,
"OP1-P-OP3 angle {op1_p_op3:.1}° should be ~{tetrahedral_angle}°"
);
assert!(
(op2_p_op3 - tetrahedral_angle).abs() < tolerance,
"OP2-P-OP3 angle {op2_p_op3:.1}° should be ~{tetrahedral_angle}°"
);
assert!(
(o5_p_op3 - tetrahedral_angle).abs() < tolerance,
"O5'-P-OP3 angle {o5_p_op3:.1}° should be ~{tetrahedral_angle}°"
);
}
}