use crate::core::PdbStructure;
use crate::error::PdbError;
use super::transform::{AtomSelection, CoordWithResidue, extract_coords_with_residue_info};
/// Tunable parameters for an lDDT (Local Distance Difference Test) calculation.
///
/// Start from [`LddtOptions::default`] and adjust individual settings with the
/// `with_*` builder methods.
#[derive(Debug, Clone, PartialEq)]
pub struct LddtOptions {
    /// Cutoff applied to distances measured in the *reference* structure: an atom
    /// pair whose reference distance exceeds this value is left out of the score.
    /// Expressed in the same units as the atom coordinates.
    pub inclusion_radius: f64,
    /// Distance-difference tolerances. A pair counts as preserved for a tolerance
    /// when `|d_reference - d_model|` is strictly smaller than it; the final score
    /// is the mean of the preserved fractions over all tolerances.
    pub thresholds: Vec<f64>,
}
impl Default for LddtOptions {
fn default() -> Self {
Self {
inclusion_radius: 15.0,
thresholds: vec![0.5, 1.0, 2.0, 4.0],
}
}
}
impl LddtOptions {
    /// Creates options populated with the default settings.
    ///
    /// Equivalent to `LddtOptions::default()`.
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns the options with `inclusion_radius` replaced by `radius`.
    ///
    /// The value is stored as given; no range validation is performed.
    pub fn with_inclusion_radius(self, radius: f64) -> Self {
        Self {
            inclusion_radius: radius,
            ..self
        }
    }

    /// Returns the options with the tolerance list replaced by `thresholds`.
    ///
    /// The list is stored as given (it may be empty and is not sorted).
    pub fn with_thresholds(self, thresholds: Vec<f64>) -> Self {
        Self { thresholds, ..self }
    }
}
/// Outcome of a whole-structure lDDT calculation.
#[derive(Debug, Clone, PartialEq)]
pub struct LddtResult {
    /// Overall score: the mean of `per_threshold_scores`, or `1.0` when that list
    /// is empty.
    pub score: f64,
    /// Number of atom pairs whose reference distance fell within the inclusion
    /// radius and which therefore contributed to the score.
    pub num_pairs: usize,
    /// Fraction of contributing pairs preserved under each tolerance, in the same
    /// order as the tolerances supplied in the options. Every entry is `1.0` when
    /// no pair contributed.
    pub per_threshold_scores: Vec<f64>,
    /// Number of distinct residue keys seen among the selected reference atoms.
    pub num_residues: usize,
}
/// lDDT score attributed to a single residue of the reference structure.
#[derive(Debug, Clone, PartialEq)]
pub struct PerResidueLddt {
    /// Key identifying the residue, copied from the reference coordinate record
    /// (presumably chain identifier and residue sequence number — confirm against
    /// `CoordWithResidue`).
    pub residue_id: (String, i32),
    /// Residue name taken from the first selected reference atom of this residue.
    pub residue_name: String,
    /// Mean preserved fraction over all tolerances for the pairs this residue takes
    /// part in; `1.0` when the residue takes part in no pair.
    pub score: f64,
    /// Number of pair memberships counted for this residue. A pair is counted once
    /// for each of its two atoms.
    pub num_pairs: usize,
}
/// Computes the global lDDT score of `model` against `reference`.
///
/// The atoms picked by `selection` are extracted from both structures and paired
/// by their position in the extracted lists; only the two counts are compared
/// before scoring, not the residue identities.
///
/// # Errors
///
/// * [`PdbError::NoAtomsSelected`] when the selection matches nothing in the model
///   (checked first) or in the reference.
/// * [`PdbError::AtomCountMismatch`] when the two structures yield a different
///   number of selected atoms.
pub fn calculate_lddt(
    model: &PdbStructure,
    reference: &PdbStructure,
    selection: AtomSelection,
    options: LddtOptions,
) -> Result<LddtResult, PdbError> {
    let model_coords = extract_coords_with_residue_info(model, &selection, None);
    let ref_coords = extract_coords_with_residue_info(reference, &selection, None);

    // Arm order matters: an empty model is reported before an empty reference,
    // and both before a plain count mismatch.
    match (model_coords.len(), ref_coords.len()) {
        (0, _) => Err(PdbError::NoAtomsSelected(format!(
            "No atoms matching {:?} selection in model structure",
            selection
        ))),
        (_, 0) => Err(PdbError::NoAtomsSelected(format!(
            "No atoms matching {:?} selection in reference structure",
            selection
        ))),
        (found, expected) if found != expected => {
            Err(PdbError::AtomCountMismatch { expected, found })
        }
        _ => calculate_lddt_from_coords(&model_coords, &ref_coords, &options),
    }
}
/// Computes an lDDT score for every residue of `reference`, comparing it with
/// `model`.
///
/// The atoms picked by `selection` are extracted from both structures and paired
/// by their position in the extracted lists; only the two counts are compared
/// before scoring, not the residue identities.
///
/// # Errors
///
/// * [`PdbError::NoAtomsSelected`] when the selection matches nothing in the model
///   (checked first) or in the reference.
/// * [`PdbError::AtomCountMismatch`] when the two structures yield a different
///   number of selected atoms.
pub fn per_residue_lddt(
    model: &PdbStructure,
    reference: &PdbStructure,
    selection: AtomSelection,
    options: LddtOptions,
) -> Result<Vec<PerResidueLddt>, PdbError> {
    let model_coords = extract_coords_with_residue_info(model, &selection, None);
    let ref_coords = extract_coords_with_residue_info(reference, &selection, None);

    // Arm order matters: an empty model is reported before an empty reference,
    // and both before a plain count mismatch.
    match (model_coords.len(), ref_coords.len()) {
        (0, _) => Err(PdbError::NoAtomsSelected(format!(
            "No atoms matching {:?} selection in model structure",
            selection
        ))),
        (_, 0) => Err(PdbError::NoAtomsSelected(format!(
            "No atoms matching {:?} selection in reference structure",
            selection
        ))),
        (found, expected) if found != expected => {
            Err(PdbError::AtomCountMismatch { expected, found })
        }
        _ => per_residue_lddt_from_coords(&model_coords, &ref_coords, &options),
    }
}
/// Euclidean distance between two points given as `(x, y, z)` tuples.
#[inline]
fn distance(p1: &(f64, f64, f64), p2: &(f64, f64, f64)) -> f64 {
    let (x1, y1, z1) = *p1;
    let (x2, y2, z2) = *p2;
    let (dx, dy, dz) = (x1 - x2, y1 - y2, z1 - z2);
    let squared = dx * dx + dy * dy + dz * dz;
    squared.sqrt()
}
fn calculate_lddt_from_coords(
model_coords: &[CoordWithResidue],
ref_coords: &[CoordWithResidue],
options: &LddtOptions,
) -> Result<LddtResult, PdbError> {
let n = ref_coords.len();
let inclusion_radius_sq = options.inclusion_radius * options.inclusion_radius;
let mut preserved_counts: Vec<usize> = vec![0; options.thresholds.len()];
let mut total_pairs: usize = 0;
let mut unique_residues = std::collections::HashSet::new();
for i in 0..n {
let ref_i = &ref_coords[i];
let model_i = &model_coords[i];
unique_residues.insert((ref_i.0.0.clone(), ref_i.0.1));
for j in (i + 1)..n {
let ref_j = &ref_coords[j];
let model_j = &model_coords[j];
let d_ref = distance(&ref_i.1, &ref_j.1);
if d_ref * d_ref > inclusion_radius_sq {
continue;
}
let d_model = distance(&model_i.1, &model_j.1);
let diff = (d_ref - d_model).abs();
for (k, threshold) in options.thresholds.iter().enumerate() {
if diff < *threshold {
preserved_counts[k] += 1;
}
}
total_pairs += 1;
}
}
let per_threshold_scores: Vec<f64> = if total_pairs > 0 {
preserved_counts
.iter()
.map(|&count| count as f64 / total_pairs as f64)
.collect()
} else {
vec![1.0; options.thresholds.len()] };
let score = if per_threshold_scores.is_empty() {
1.0
} else {
per_threshold_scores.iter().sum::<f64>() / per_threshold_scores.len() as f64
};
Ok(LddtResult {
score,
num_pairs: total_pairs,
per_threshold_scores,
num_residues: unique_residues.len(),
})
}
fn per_residue_lddt_from_coords(
model_coords: &[CoordWithResidue],
ref_coords: &[CoordWithResidue],
options: &LddtOptions,
) -> Result<Vec<PerResidueLddt>, PdbError> {
use std::collections::HashMap;
let n = ref_coords.len();
let inclusion_radius_sq = options.inclusion_radius * options.inclusion_radius;
let num_thresholds = options.thresholds.len();
let mut residue_stats: HashMap<(String, i32), (Vec<usize>, usize, String)> = HashMap::new();
for ref_coord in ref_coords {
let key = (ref_coord.0.0.clone(), ref_coord.0.1);
residue_stats
.entry(key)
.or_insert_with(|| (vec![0; num_thresholds], 0, ref_coord.0.2.clone()));
}
for i in 0..n {
let ref_i = &ref_coords[i];
let model_i = &model_coords[i];
for j in (i + 1)..n {
let ref_j = &ref_coords[j];
let model_j = &model_coords[j];
let d_ref = distance(&ref_i.1, &ref_j.1);
if d_ref * d_ref > inclusion_radius_sq {
continue;
}
let d_model = distance(&model_i.1, &model_j.1);
let diff = (d_ref - d_model).abs();
let key_i = (ref_i.0.0.clone(), ref_i.0.1);
if let Some((preserved, total, _)) = residue_stats.get_mut(&key_i) {
for (k, threshold) in options.thresholds.iter().enumerate() {
if diff < *threshold {
preserved[k] += 1;
}
}
*total += 1;
}
let key_j = (ref_j.0.0.clone(), ref_j.0.1);
if let Some((preserved, total, _)) = residue_stats.get_mut(&key_j) {
for (k, threshold) in options.thresholds.iter().enumerate() {
if diff < *threshold {
preserved[k] += 1;
}
}
*total += 1;
}
}
}
let mut results: Vec<PerResidueLddt> = residue_stats
.into_iter()
.map(|(key, (preserved, total, residue_name))| {
let score = if total > 0 {
let per_threshold: Vec<f64> =
preserved.iter().map(|&p| p as f64 / total as f64).collect();
per_threshold.iter().sum::<f64>() / num_thresholds as f64
} else {
1.0 };
PerResidueLddt {
residue_id: key,
residue_name,
score,
num_pairs: total,
}
})
.collect();
results.sort_by(|a, b| {
a.residue_id
.0
.cmp(&b.residue_id.0)
.then(a.residue_id.1.cmp(&b.residue_id.1))
});
Ok(results)
}
// Unit tests for the lDDT routines, built on small synthetic structures whose
// atoms all lie on a straight line.
#[cfg(test)]
mod tests {
use super::*;
use crate::records::Atom;
// Builds a single atom named "CA" in an "ALA" residue at the given position.
// The atom serial number is set to the residue sequence number.
fn create_atom(x: f64, y: f64, z: f64, residue_seq: i32, chain_id: &str) -> Atom {
Atom {
serial: residue_seq,
name: "CA".to_string(),
alt_loc: None,
residue_name: "ALA".to_string(),
chain_id: chain_id.to_string(),
residue_seq,
x,
y,
z,
occupancy: 1.0,
temp_factor: 0.0,
element: "C".to_string(),
ins_code: None,
is_hetatm: false,
}
}
// Builds a five-atom structure (residues 1..=5 of chain "A") with the atoms
// placed along the x axis at multiples of `spacing`.
fn create_linear_structure(spacing: f64) -> PdbStructure {
let mut structure = PdbStructure::new();
structure.atoms = vec![
create_atom(0.0, 0.0, 0.0, 1, "A"),
create_atom(spacing, 0.0, 0.0, 2, "A"),
create_atom(spacing * 2.0, 0.0, 0.0, 3, "A"),
create_atom(spacing * 3.0, 0.0, 0.0, 4, "A"),
create_atom(spacing * 4.0, 0.0, 0.0, 5, "A"),
];
structure
}
// A structure compared with itself must score exactly 1.0.
#[test]
fn test_lddt_self_comparison() {
let structure = create_linear_structure(3.8);
let result = calculate_lddt(
&structure,
&structure,
AtomSelection::CaOnly,
LddtOptions::default(),
)
.unwrap();
assert!(
(result.score - 1.0).abs() < 1e-10,
"Self-LDDT should be 1.0, got {}",
result.score
);
}
// Shifting every model atom by the same vector leaves all internal distances,
// and therefore the score, unchanged.
#[test]
fn test_lddt_translation_invariance() {
let reference = create_linear_structure(3.8);
let mut model = create_linear_structure(3.8);
for atom in &mut model.atoms {
atom.x += 100.0;
atom.y += 50.0;
atom.z += 25.0;
}
let result = calculate_lddt(
&model,
&reference,
AtomSelection::CaOnly,
LddtOptions::default(),
)
.unwrap();
assert!(
(result.score - 1.0).abs() < 1e-10,
"LDDT should be translation invariant, got {}",
result.score
);
}
// Maps (x, y) to (-y, x) for every model atom — a quarter turn about the z
// axis — which must not change the score either.
#[test]
fn test_lddt_rotation_invariance() {
let reference = create_linear_structure(3.8);
let mut model = create_linear_structure(3.8);
for atom in &mut model.atoms {
let x = atom.x;
let y = atom.y;
atom.x = -y;
atom.y = x;
}
let result = calculate_lddt(
&model,
&reference,
AtomSelection::CaOnly,
LddtOptions::default(),
)
.unwrap();
assert!(
(result.score - 1.0).abs() < 1e-10,
"LDDT should be rotation invariant, got {}",
result.score
);
}
// Moving one atom (index 2) off the line by 5.0 must lower the score below
// 1.0 without driving it all the way to 0.0.
#[test]
fn test_lddt_perturbed_structure() {
let reference = create_linear_structure(3.8);
let mut model = create_linear_structure(3.8);
model.atoms[2].y += 5.0;
let result = calculate_lddt(
&model,
&reference,
AtomSelection::CaOnly,
LddtOptions::default(),
)
.unwrap();
assert!(
result.score < 1.0,
"LDDT should be < 1.0 for perturbed structure, got {}",
result.score
);
assert!(
result.score > 0.0,
"LDDT should be > 0.0 for perturbed structure"
);
}
// With the same perturbation, the tighter tolerance set [0.5, 1.0] must never
// score higher than the looser set [2.0, 4.0].
#[test]
fn test_lddt_custom_options() {
let reference = create_linear_structure(3.8);
let mut model = create_linear_structure(3.8);
model.atoms[2].y += 1.5;
let options_lenient = LddtOptions::default().with_thresholds(vec![2.0, 4.0]);
let options_strict = LddtOptions::default().with_thresholds(vec![0.5, 1.0]);
let result_lenient =
calculate_lddt(&model, &reference, AtomSelection::CaOnly, options_lenient).unwrap();
let result_strict =
calculate_lddt(&model, &reference, AtomSelection::CaOnly, options_strict).unwrap();
assert!(
result_strict.score <= result_lenient.score,
"Stricter thresholds should give lower or equal LDDT: {} vs {}",
result_strict.score,
result_lenient.score
);
}
// A smaller inclusion radius can only reduce the number of scored pairs.
// With atoms 10.0 apart, no reference pair lies within a radius of 5.0.
// Only the pair counts are compared; the scores themselves are not checked.
#[test]
fn test_lddt_inclusion_radius() {
let reference = create_linear_structure(10.0); let mut model = create_linear_structure(10.0);
model.atoms[4].y += 2.0;
let options_small_radius = LddtOptions::default().with_inclusion_radius(5.0);
let options_large_radius = LddtOptions::default().with_inclusion_radius(50.0);
let result_small = calculate_lddt(
&model,
&reference,
AtomSelection::CaOnly,
options_small_radius,
)
.unwrap();
let result_large = calculate_lddt(
&model,
&reference,
AtomSelection::CaOnly,
options_large_radius,
)
.unwrap();
assert!(result_small.num_pairs <= result_large.num_pairs);
}
// The displaced residue (sequence number 3) must score lower than the average
// of the four untouched residues, and one entry per residue is expected.
#[test]
fn test_per_residue_lddt() {
let reference = create_linear_structure(3.8);
let mut model = create_linear_structure(3.8);
model.atoms[2].y += 5.0;
let per_res = per_residue_lddt(
&model,
&reference,
AtomSelection::CaOnly,
LddtOptions::default(),
)
.unwrap();
assert_eq!(per_res.len(), 5, "Should have 5 residues");
let perturbed = per_res.iter().find(|r| r.residue_id.1 == 3).unwrap();
let others: Vec<_> = per_res.iter().filter(|r| r.residue_id.1 != 3).collect();
let avg_others = others.iter().map(|r| r.score).sum::<f64>() / others.len() as f64;
assert!(
perturbed.score < avg_others,
"Perturbed residue should have lower LDDT: {} vs {}",
perturbed.score,
avg_others
);
}
// A freshly created structure has nothing to select, which must be reported
// as `NoAtomsSelected` rather than scored.
#[test]
fn test_lddt_empty_structure() {
let structure = PdbStructure::new();
let result = calculate_lddt(
&structure,
&structure,
AtomSelection::CaOnly,
LddtOptions::default(),
);
assert!(matches!(result, Err(PdbError::NoAtomsSelected(_))));
}
// Dropping one atom from one of the structures must yield `AtomCountMismatch`.
#[test]
fn test_lddt_mismatched_structures() {
let structure1 = create_linear_structure(3.8);
let mut structure2 = create_linear_structure(3.8);
structure2.atoms.pop();
let result = calculate_lddt(
&structure1,
&structure2,
AtomSelection::CaOnly,
LddtOptions::default(),
);
assert!(matches!(result, Err(PdbError::AtomCountMismatch { .. })));
}
// The builder methods store exactly the values they are given.
#[test]
fn test_lddt_options_builder() {
let options = LddtOptions::new()
.with_inclusion_radius(10.0)
.with_thresholds(vec![0.25, 0.5, 1.0]);
assert_eq!(options.inclusion_radius, 10.0);
assert_eq!(options.thresholds, vec![0.25, 0.5, 1.0]);
}
// Sanity check of the distance helper on a 3-4-5 right triangle.
#[test]
fn test_distance_function() {
let p1 = (0.0, 0.0, 0.0);
let p2 = (3.0, 4.0, 0.0);
let d = distance(&p1, &p2);
assert!((d - 5.0).abs() < 1e-10, "Distance should be 5.0, got {}", d);
}
}