use crate::PdbStructure;
use std::collections::HashMap;
/// Confidence band assigned to a pLDDT score.
///
/// The numeric limits of each band are defined by
/// [`ConfidenceCategory::from_plddt`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConfidenceCategory {
    /// Score strictly greater than 90.
    VeryHigh,
    /// Score from 70 up to and including 90.
    Confident,
    /// Score from 50 up to, but not including, 70.
    Low,
    /// Score below 50 (also used when the score is NaN).
    VeryLow,
}

impl ConfidenceCategory {
    /// Maps a pLDDT score onto its confidence band.
    ///
    /// The upper band is exclusive at its lower edge (`90.0` is
    /// [`Confident`](Self::Confident), not [`VeryHigh`](Self::VeryHigh)), while
    /// the `70.0` and `50.0` edges are inclusive. The input is not range
    /// checked: anything that fails all three comparisons, including NaN and
    /// negative values, ends up in [`VeryLow`](Self::VeryLow).
    pub fn from_plddt(plddt: f64) -> Self {
        match plddt {
            score if score > 90.0 => Self::VeryHigh,
            score if score >= 70.0 => Self::Confident,
            score if score >= 50.0 => Self::Low,
            // Everything else, NaN included, because every guard above is false.
            _ => Self::VeryLow,
        }
    }

    /// Returns `true` for the two upper bands (`VeryHigh` and `Confident`).
    pub fn is_reliable(&self) -> bool {
        // Spelled out per variant so that adding a variant forces a decision here.
        match self {
            Self::VeryHigh | Self::Confident => true,
            Self::Low | Self::VeryLow => false,
        }
    }

    /// Returns `true` for the two lower bands (`Low` and `VeryLow`).
    ///
    /// With the four current variants this is the exact complement of
    /// [`is_reliable`](Self::is_reliable).
    pub fn needs_caution(&self) -> bool {
        !self.is_reliable()
    }
}
/// pLDDT summary for a single residue.
///
/// `PdbStructure::per_residue_plddt` fills one of these per distinct
/// `(chain_id, residue_seq, ins_code)` combination found among the atoms.
#[derive(Debug, Clone)]
pub struct ResiduePlddt {
    /// Chain identifier of the residue.
    pub chain_id: String,
    /// Residue sequence number.
    pub residue_seq: i32,
    /// Residue name, e.g. `"ALA"`.
    pub residue_name: String,
    /// Insertion code, if the residue has one.
    pub ins_code: Option<char>,
    /// Mean of the per-atom scores of this residue.
    pub plddt: f64,
    /// Smallest per-atom score of this residue.
    pub plddt_min: f64,
    /// Largest per-atom score of this residue.
    pub plddt_max: f64,
    /// Number of atoms that contributed to the statistics.
    pub atom_count: usize,
    /// Confidence band of the residue.
    pub confidence_category: ConfidenceCategory,
}

impl ResiduePlddt {
    /// Returns `true` when the residue's confidence band counts as reliable.
    ///
    /// This reads `confidence_category` only; the numeric `plddt` field is not
    /// re-evaluated.
    pub fn is_confident(&self) -> bool {
        let category = self.confidence_category;
        category.is_reliable()
    }

    /// Returns `true` only when the residue sits in the `VeryLow` band.
    pub fn is_disordered(&self) -> bool {
        self.confidence_category == ConfidenceCategory::VeryLow
    }
}
impl PdbStructure {
    /// Heuristically decides whether this structure is a computational
    /// prediction.
    ///
    /// The checks run in this order and the first conclusive one wins:
    ///
    /// 1. The header or the title contains (case-insensitively) one of the
    ///    markers `ALPHAFOLD`, `PREDICTED`, `ESMFOLD`, `COLABFOLD` or
    ///    `ROSETTAFOLD`. Both fields are tested against the same list.
    /// 2. With no atoms the answer is `false`.
    /// 3. Otherwise the temperature-factor column is inspected: it "looks like
    ///    pLDDT" when its minimum is non-negative and its maximum lies in
    ///    `50.0..=100.5`. For structures with more than 100 atoms the mean
    ///    must additionally lie strictly between 30 and 95.
    ///
    /// NOTE(review): the mean window excludes large models whose mean score is
    /// 95 or higher, while the same values in a structure of 100 atoms or
    /// fewer are accepted. Confirm that this asymmetry is intended.
    pub fn is_predicted(&self) -> bool {
        // One marker list for both text fields so they cannot drift apart.
        const PREDICTION_MARKERS: [&str; 5] = [
            "ALPHAFOLD",
            "PREDICTED",
            "ESMFOLD",
            "COLABFOLD",
            "ROSETTAFOLD",
        ];

        // `Option::iter` yields zero or one item, so this visits the header
        // first and then the title, skipping whichever is absent.
        let named_as_prediction = self.header.iter().chain(self.title.iter()).any(|text| {
            let upper = text.to_uppercase();
            PREDICTION_MARKERS
                .iter()
                .any(|marker| upper.contains(*marker))
        });
        if named_as_prediction {
            return true;
        }

        if self.atoms.is_empty() {
            return false;
        }

        // Fold straight over the atoms; no temporary vector is needed.
        let min_b = self
            .atoms
            .iter()
            .map(|atom| atom.temp_factor)
            .fold(f64::MAX, f64::min);
        let max_b = self
            .atoms
            .iter()
            .map(|atom| atom.temp_factor)
            .fold(f64::MIN, f64::max);

        let looks_like_plddt = min_b >= 0.0 && (50.0..=100.5).contains(&max_b);
        if looks_like_plddt && self.atoms.len() > 100 {
            // Enough atoms for the mean to be meaningful: apply the mean window.
            let mean_b = self.plddt_mean();
            return mean_b > 30.0 && mean_b < 95.0;
        }
        looks_like_plddt
    }

    /// Returns the temperature factor of every atom, in atom order.
    ///
    /// The values are returned as stored; no check is made that the structure
    /// is actually a prediction.
    pub fn plddt_scores(&self) -> Vec<f64> {
        self.atoms.iter().map(|atom| atom.temp_factor).collect()
    }

    /// Returns the mean temperature factor over all atoms, or `0.0` when the
    /// structure has no atoms.
    pub fn plddt_mean(&self) -> f64 {
        if self.atoms.is_empty() {
            return 0.0;
        }
        let total: f64 = self.atoms.iter().map(|atom| atom.temp_factor).sum();
        total / self.atoms.len() as f64
    }

    /// Aggregates the per-atom temperature factors into one entry per residue.
    ///
    /// Atoms are grouped by `(chain_id, residue_seq, ins_code)`. Each entry
    /// carries the mean, minimum and maximum of the group, the number of atoms
    /// in it, the residue name of the first atom seen for that key, and the
    /// confidence band of the mean. The result is sorted by chain id, then
    /// residue number, then insertion code (`None` before `Some`).
    pub fn per_residue_plddt(&self) -> Vec<ResiduePlddt> {
        type ResidueKey = (String, i32, Option<char>);

        // A single map holds both the residue name and its scores, so every
        // atom costs one key and one lookup.
        let mut by_residue: HashMap<ResidueKey, (String, Vec<f64>)> = HashMap::new();
        for atom in &self.atoms {
            by_residue
                .entry((atom.chain_id.clone(), atom.residue_seq, atom.ins_code))
                .or_insert_with(|| (atom.residue_name.clone(), Vec::new()))
                .1
                .push(atom.temp_factor);
        }

        let mut results: Vec<ResiduePlddt> = by_residue
            .into_iter()
            .map(
                |((chain_id, residue_seq, ins_code), (residue_name, scores))| {
                    // Every group holds at least one score, so the division is safe.
                    let atom_count = scores.len();
                    let mean = scores.iter().sum::<f64>() / atom_count as f64;
                    ResiduePlddt {
                        chain_id,
                        residue_seq,
                        residue_name,
                        ins_code,
                        plddt: mean,
                        plddt_min: scores.iter().copied().fold(f64::MAX, f64::min),
                        plddt_max: scores.iter().copied().fold(f64::MIN, f64::max),
                        atom_count,
                        confidence_category: ConfidenceCategory::from_plddt(mean),
                    }
                },
            )
            .collect();

        // Map iteration order is arbitrary. The sort keys are unique (they are
        // the map keys), so an unstable sort yields a deterministic order.
        results.sort_unstable_by(|a, b| {
            (a.chain_id.as_str(), a.residue_seq, a.ins_code).cmp(&(
                b.chain_id.as_str(),
                b.residue_seq,
                b.ins_code,
            ))
        });
        results
    }

    /// Returns the residues whose mean score is strictly below `threshold`,
    /// in the order produced by [`per_residue_plddt`](Self::per_residue_plddt).
    pub fn low_confidence_regions(&self, threshold: f64) -> Vec<ResiduePlddt> {
        let mut residues = self.per_residue_plddt();
        residues.retain(|residue| residue.plddt < threshold);
        residues
    }

    /// Returns the residues whose mean score is at least `threshold`,
    /// in the order produced by [`per_residue_plddt`](Self::per_residue_plddt).
    pub fn high_confidence_regions(&self, threshold: f64) -> Vec<ResiduePlddt> {
        let mut residues = self.per_residue_plddt();
        residues.retain(|residue| residue.plddt >= threshold);
        residues
    }

    /// Returns the fraction of residues in each confidence band as
    /// `(very_high, confident, low, very_low)`.
    ///
    /// The four fractions are each `count / number_of_residues`. A structure
    /// without residues yields `(0.0, 0.0, 0.0, 0.0)`.
    pub fn plddt_distribution(&self) -> (f64, f64, f64, f64) {
        let residues = self.per_residue_plddt();
        if residues.is_empty() {
            return (0.0, 0.0, 0.0, 0.0);
        }

        // Slots: 0 = very high, 1 = confident, 2 = low, 3 = very low.
        let mut counts = [0usize; 4];
        for residue in &residues {
            let slot = match residue.confidence_category {
                ConfidenceCategory::VeryHigh => 0,
                ConfidenceCategory::Confident => 1,
                ConfidenceCategory::Low => 2,
                ConfidenceCategory::VeryLow => 3,
            };
            counts[slot] += 1;
        }

        let total = residues.len() as f64;
        (
            counts[0] as f64 / total,
            counts[1] as f64 / total,
            counts[2] as f64 / total,
            counts[3] as f64 / total,
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::records::Atom;

    /// Builds a single `CA` atom at the origin whose temperature factor is
    /// `bfactor`; the residue number doubles as the atom serial.
    fn make_atom(chain: &str, resid: i32, res_name: &str, bfactor: f64) -> Atom {
        Atom {
            serial: resid,
            name: "CA".to_string(),
            alt_loc: None,
            residue_name: res_name.to_string(),
            chain_id: chain.to_string(),
            residue_seq: resid,
            ins_code: None,
            is_hetatm: false,
            x: 0.0,
            y: 0.0,
            z: 0.0,
            occupancy: 1.0,
            temp_factor: bfactor,
            element: "C".to_string(),
        }
    }

    #[test]
    fn test_confidence_category_from_plddt() {
        assert_eq!(
            ConfidenceCategory::from_plddt(95.0),
            ConfidenceCategory::VeryHigh
        );
        assert_eq!(
            ConfidenceCategory::from_plddt(80.0),
            ConfidenceCategory::Confident
        );
        assert_eq!(
            ConfidenceCategory::from_plddt(60.0),
            ConfidenceCategory::Low
        );
        assert_eq!(
            ConfidenceCategory::from_plddt(30.0),
            ConfidenceCategory::VeryLow
        );
    }

    #[test]
    fn test_confidence_category_boundaries() {
        // 90.0 is not strictly greater than 90, so it stays in `Confident`.
        assert_eq!(
            ConfidenceCategory::from_plddt(90.0),
            ConfidenceCategory::Confident
        );
        // The 70 and 50 edges are inclusive.
        assert_eq!(
            ConfidenceCategory::from_plddt(70.0),
            ConfidenceCategory::Confident
        );
        assert_eq!(
            ConfidenceCategory::from_plddt(50.0),
            ConfidenceCategory::Low
        );
    }

    #[test]
    fn test_confidence_category_methods() {
        assert!(ConfidenceCategory::VeryHigh.is_reliable());
        assert!(ConfidenceCategory::Confident.is_reliable());
        assert!(!ConfidenceCategory::Low.is_reliable());
        assert!(!ConfidenceCategory::VeryLow.is_reliable());
        assert!(!ConfidenceCategory::VeryHigh.needs_caution());
        assert!(!ConfidenceCategory::Confident.needs_caution());
        assert!(ConfidenceCategory::Low.needs_caution());
        assert!(ConfidenceCategory::VeryLow.needs_caution());
    }

    #[test]
    fn test_is_predicted_by_header() {
        let mut structure = PdbStructure::new();
        structure.header = Some("ALPHAFOLD MODEL".to_string());
        assert!(structure.is_predicted());
        // Matching is case-insensitive.
        structure.header = Some("predicted structure".to_string());
        assert!(structure.is_predicted());
        // The title is consulted when the header is absent.
        structure.header = None;
        structure.title = Some("ESMFold prediction".to_string());
        assert!(structure.is_predicted());
    }

    #[test]
    fn test_is_predicted_by_bfactor() {
        let mut structure = PdbStructure::new();
        // Temperature factors between 70 and 89: inside the pLDDT-like range.
        for i in 1..=100 {
            structure
                .atoms
                .push(make_atom("A", i, "ALA", 70.0 + (i as f64 % 20.0)));
        }
        assert!(structure.is_predicted());
    }

    #[test]
    fn test_is_not_predicted_experimental() {
        let mut structure = PdbStructure::new();
        // Temperature factors between 15 and 44: the maximum stays below 50,
        // so the B-factor heuristic must not fire.
        for i in 1..=100 {
            structure
                .atoms
                .push(make_atom("A", i, "ALA", 15.0 + (i as f64 % 30.0)));
        }
        assert!(!structure.is_predicted());
    }

    #[test]
    fn test_plddt_mean() {
        let mut structure = PdbStructure::new();
        structure.atoms.push(make_atom("A", 1, "ALA", 90.0));
        structure.atoms.push(make_atom("A", 2, "GLY", 80.0));
        assert!((structure.plddt_mean() - 85.0).abs() < 0.01);
    }

    #[test]
    fn test_plddt_mean_empty() {
        let structure = PdbStructure::new();
        assert_eq!(structure.plddt_mean(), 0.0);
    }

    #[test]
    fn test_per_residue_plddt() {
        let mut structure = PdbStructure::new();
        structure.atoms.push(make_atom("A", 1, "ALA", 90.0));
        structure.atoms.push(make_atom("A", 2, "GLY", 50.0));
        let profile = structure.per_residue_plddt();
        assert_eq!(profile.len(), 2);
        let res1 = profile.iter().find(|r| r.residue_seq == 1).unwrap();
        assert!((res1.plddt - 90.0).abs() < 0.01);
        // Exactly 90.0 is still `Confident`; `VeryHigh` starts above 90.
        assert_eq!(res1.confidence_category, ConfidenceCategory::Confident);
        let res2 = profile.iter().find(|r| r.residue_seq == 2).unwrap();
        assert!((res2.plddt - 50.0).abs() < 0.01);
        assert_eq!(res2.confidence_category, ConfidenceCategory::Low);
    }

    #[test]
    fn test_per_residue_plddt_is_sorted() {
        let mut structure = PdbStructure::new();
        structure.atoms.push(make_atom("B", 1, "ALA", 90.0));
        structure.atoms.push(make_atom("A", 2, "GLY", 50.0));
        structure.atoms.push(make_atom("A", 1, "VAL", 30.0));
        let order: Vec<(String, i32)> = structure
            .per_residue_plddt()
            .into_iter()
            .map(|r| (r.chain_id, r.residue_seq))
            .collect();
        assert_eq!(
            order,
            vec![
                ("A".to_string(), 1),
                ("A".to_string(), 2),
                ("B".to_string(), 1),
            ]
        );
    }

    #[test]
    fn test_low_confidence_regions() {
        let mut structure = PdbStructure::new();
        structure.atoms.push(make_atom("A", 1, "ALA", 90.0));
        structure.atoms.push(make_atom("A", 2, "GLY", 50.0));
        structure.atoms.push(make_atom("A", 3, "VAL", 30.0));
        let low_conf = structure.low_confidence_regions(70.0);
        assert_eq!(low_conf.len(), 2);
    }

    #[test]
    fn test_high_confidence_regions() {
        let mut structure = PdbStructure::new();
        structure.atoms.push(make_atom("A", 1, "ALA", 90.0));
        structure.atoms.push(make_atom("A", 2, "GLY", 70.0));
        structure.atoms.push(make_atom("A", 3, "VAL", 30.0));
        // The threshold itself is included.
        let high_conf = structure.high_confidence_regions(70.0);
        assert_eq!(high_conf.len(), 2);
    }

    #[test]
    fn test_plddt_distribution() {
        let mut structure = PdbStructure::new();
        structure.atoms.push(make_atom("A", 1, "ALA", 95.0));
        structure.atoms.push(make_atom("A", 2, "GLY", 80.0));
        structure.atoms.push(make_atom("A", 3, "VAL", 60.0));
        structure.atoms.push(make_atom("A", 4, "LEU", 40.0));
        let (very_high, confident, low, very_low) = structure.plddt_distribution();
        assert!((very_high - 0.25).abs() < 0.01);
        assert!((confident - 0.25).abs() < 0.01);
        assert!((low - 0.25).abs() < 0.01);
        assert!((very_low - 0.25).abs() < 0.01);
    }

    #[test]
    fn test_residue_plddt_methods() {
        let res = ResiduePlddt {
            chain_id: "A".to_string(),
            residue_seq: 1,
            residue_name: "ALA".to_string(),
            ins_code: None,
            plddt: 90.0,
            plddt_min: 85.0,
            plddt_max: 95.0,
            atom_count: 5,
            confidence_category: ConfidenceCategory::Confident,
        };
        assert!(res.is_confident());
        assert!(!res.is_disordered());
    }
}