use super::descriptors::MolecularDescriptors;
use serde::{Deserialize, Serialize};
/// Bundle of property estimates produced by `predict_properties` for one molecule.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MlPropertyResult {
/// Value returned by `predict_logp` (additive descriptor-based estimate).
pub logp: f64,
/// Value returned by `predict_molar_refractivity`.
pub molar_refractivity: f64,
/// Value returned by `predict_solubility`.
/// NOTE(review): presumably log10 of aqueous solubility in mol/L — confirm units.
pub log_solubility: f64,
/// Per-criterion outcome of the Lipinski-style checks.
pub lipinski: LipinskiResult,
/// Heuristic score from `druglikeness_score`, clamped to `[0.0, 1.0]`.
pub druglikeness: f64,
/// Confidence, spread estimates and warnings from `estimate_uncertainty`.
pub uncertainty: PredictionUncertainty,
}
/// Reliability information attached to a prediction by `estimate_uncertainty`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredictionUncertainty {
/// Heuristic confidence; starts at 1.0, is reduced for large or complex
/// molecules and is clamped to `[0.1, 1.0]` before being stored.
pub confidence: f64,
/// Spread estimate for logP: grows linearly with molecular weight and is
/// inflated as confidence drops (the divisor is floored at 0.3).
pub logp_std: f64,
/// Spread estimate for log-solubility, built the same way as `logp_std`
/// but with larger base and slope.
pub solubility_std: f64,
/// Human-readable messages explaining the larger confidence penalties.
pub warnings: Vec<String>,
}
/// Outcome of the four Lipinski-style threshold checks made in `predict_properties`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LipinskiResult {
/// `true` when molecular weight is at most 500.
pub mw_ok: bool,
/// `true` when the predicted logP is at most 5.
pub logp_ok: bool,
/// `true` when the hydrogen-bond donor count is at most 5.
pub hbd_ok: bool,
/// `true` when the hydrogen-bond acceptor count is at most 10.
pub hba_ok: bool,
/// Number of the four checks above that failed (0 to 4).
pub violations: u8,
/// `true` when at most one check failed.
pub passes: bool,
}
/// Estimates logP as a linear combination of simple molecular descriptors.
///
/// Heavy atoms, rings, aromatic atoms and molecular weight above 100 raise the
/// estimate; hydrogen-bond donors, hydrogen-bond acceptors and a high sp3
/// fraction lower it.
fn predict_logp(desc: &MolecularDescriptors) -> f64 {
    let heavy_atoms = desc.n_heavy_atoms as f64;
    let donors = desc.n_hbd as f64;
    let rings = desc.n_rings as f64;
    let aromatic_atoms = desc.n_aromatic as f64;
    let acceptors = desc.n_hba as f64;

    // Contributions are accumulated one at a time, in a fixed order, so the
    // floating-point result does not depend on how the sum is associated.
    let mut logp = 0.120 * heavy_atoms;
    logp += -0.230 * donors;
    logp += 0.150 * rings;
    logp += 0.080 * aromatic_atoms;
    logp += -0.310 * acceptors;
    logp += -0.180 * desc.fsp3;
    logp += 0.005 * (desc.molecular_weight - 100.0);
    logp
}
/// Estimates molar refractivity from the summed atomic polarizability, with
/// additive corrections per ring and per aromatic atom.
fn predict_molar_refractivity(desc: &MolecularDescriptors) -> f64 {
    let rings = desc.n_rings as f64;
    let aromatic_atoms = desc.n_aromatic as f64;
    2.536 * desc.sum_polarizability + 1.20 * rings + 0.80 * aromatic_atoms
}
/// Estimates log-solubility from logP, molecular weight, rotatable bonds and
/// the aromatic proportion (aromatic atoms divided by heavy atoms).
///
/// The aromatic proportion is taken as 0 when the molecule has no heavy atoms,
/// which avoids a division by zero.
///
/// NOTE(review): the coefficients look like Delaney's ESOL linear model —
/// confirm against the intended reference.
fn predict_solubility(desc: &MolecularDescriptors, logp: f64) -> f64 {
    let aromatic_proportion = match desc.n_heavy_atoms {
        heavy if heavy > 0 => desc.n_aromatic as f64 / heavy as f64,
        _ => 0.0,
    };

    let lipophilicity_term = 0.63 * logp;
    let weight_term = 0.0062 * desc.molecular_weight;
    let flexibility_term = 0.066 * desc.n_rotatable_bonds as f64;
    let aromaticity_term = 0.74 * aromatic_proportion;

    0.16 - lipophilicity_term - weight_term + flexibility_term - aromaticity_term
}
/// Scores drug-likeness in `[0.0, 1.0]` by starting from 1.0 and subtracting
/// a penalty for each property that falls outside its preferred range.
///
/// Penalties for molecular weight (> 500), logP (> 5) and rotatable bonds
/// (> 10) ramp up linearly and saturate; logP below -2 and hydrogen-bond
/// donor/acceptor counts above 5/10 cost a flat amount. A small bonus
/// proportional to `fsp3` is added before the final clamp.
fn druglikeness_score(desc: &MolecularDescriptors, logp: f64) -> f64 {
    // Fraction of the way through a ramp of width `span`, capped at 1.
    let ramp = |excess: f64, span: f64| (excess / span).min(1.0);
    let mw = desc.molecular_weight;

    let weight_penalty = if mw > 500.0 {
        0.2 * ramp(mw - 500.0, 200.0)
    } else {
        0.0
    };
    let logp_penalty = if logp > 5.0 {
        0.2 * ramp(logp - 5.0, 3.0)
    } else if logp < -2.0 {
        0.15
    } else {
        0.0
    };
    let donor_penalty = if desc.n_hbd > 5 { 0.15 } else { 0.0 };
    let acceptor_penalty = if desc.n_hba > 10 { 0.15 } else { 0.0 };
    let flexibility_penalty = if desc.n_rotatable_bonds > 10 {
        0.1 * ramp(desc.n_rotatable_bonds as f64 - 10.0, 5.0)
    } else {
        0.0
    };

    // Subtract in a fixed order; an inactive rule contributes exactly 0.0.
    let penalised = [
        weight_penalty,
        logp_penalty,
        donor_penalty,
        acceptor_penalty,
        flexibility_penalty,
    ]
    .iter()
    .fold(1.0_f64, |score, &penalty| score - penalty);

    (penalised + 0.05 * desc.fsp3).clamp(0.0, 1.0)
}
/// Derives a heuristic confidence, spread estimates and warnings for the
/// predictions made from `desc`.
///
/// Confidence starts at 1.0 and is reduced for high molecular weight, many
/// heavy atoms, many rings and many rotatable bonds. The larger penalties also
/// record a warning. The spread estimates grow linearly with (non-negative)
/// molecular weight and are divided by the confidence, floored at 0.3 so the
/// inflation stays bounded.
fn estimate_uncertainty(desc: &MolecularDescriptors) -> PredictionUncertainty {
    let mw = desc.molecular_weight;

    // (rule fired, confidence penalty, optional warning), applied in order.
    let rules: [(bool, f64, Option<&str>); 5] = [
        (
            mw > 900.0,
            0.3,
            Some("MW > 900 — outside typical training domain"),
        ),
        (mw > 600.0 && mw <= 900.0, 0.1, None),
        (
            desc.n_heavy_atoms > 70,
            0.2,
            Some("Large molecule (>70 heavy atoms) — reduced accuracy"),
        ),
        (
            desc.n_rings > 8,
            0.15,
            Some("Many rings — polycyclic compounds have higher uncertainty"),
        ),
        (desc.n_rotatable_bonds > 15, 0.1, None),
    ];

    let mut confidence = 1.0_f64;
    let mut warnings: Vec<String> = Vec::new();
    for &(fired, penalty, message) in rules.iter() {
        if !fired {
            continue;
        }
        confidence -= penalty;
        if let Some(text) = message {
            warnings.push(text.to_string());
        }
    }

    let non_negative_mw = mw.max(0.0);
    let divisor = confidence.max(0.3);

    PredictionUncertainty {
        confidence: confidence.clamp(0.1, 1.0),
        logp_std: (0.6 + 0.002 * non_negative_mw) / divisor,
        solubility_std: (0.8 + 0.003 * non_negative_mw) / divisor,
        warnings,
    }
}
/// Runs every property estimator on `desc` and bundles the results.
///
/// The predicted logP feeds the solubility estimate, the drug-likeness score
/// and the Lipinski-style check. That check tests four thresholds (molecular
/// weight ≤ 500, logP ≤ 5, H-bond donors ≤ 5, H-bond acceptors ≤ 10) and is
/// considered passed when at most one of them is violated.
pub fn predict_properties(desc: &MolecularDescriptors) -> MlPropertyResult {
    let logp = predict_logp(desc);
    let molar_refractivity = predict_molar_refractivity(desc);
    let log_solubility = predict_solubility(desc, logp);

    let mw_ok = desc.molecular_weight <= 500.0;
    let logp_ok = logp <= 5.0;
    let hbd_ok = desc.n_hbd <= 5;
    let hba_ok = desc.n_hba <= 10;

    // Each failed criterion contributes exactly one violation.
    let violations =
        u8::from(!mw_ok) + u8::from(!logp_ok) + u8::from(!hbd_ok) + u8::from(!hba_ok);

    let druglikeness = druglikeness_score(desc, logp);
    let uncertainty = estimate_uncertainty(desc);

    MlPropertyResult {
        logp,
        molar_refractivity,
        log_solubility,
        lipinski: LipinskiResult {
            mw_ok,
            logp_ok,
            hbd_ok,
            hba_ok,
            violations,
            passes: violations <= 1,
        },
        druglikeness,
        uncertainty,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ml::descriptors::compute_descriptors;

    /// Hand-built descriptor set that exceeds several Lipinski thresholds
    /// (molecular weight, H-bond donors and H-bond acceptors).
    fn oversized_descriptors() -> MolecularDescriptors {
        MolecularDescriptors {
            // Size and composition.
            molecular_weight: 800.0,
            n_heavy_atoms: 60,
            n_hydrogens: 20,
            n_bonds: 80,
            n_rotatable_bonds: 15,
            n_rings: 5,
            n_aromatic: 12,
            fsp3: 0.1,
            // Hydrogen bonding.
            n_hbd: 8,
            n_hba: 15,
            // Charges.
            total_abs_charge: 5.0,
            max_charge: 0.5,
            min_charge: -0.5,
            // Topological and electronic sums.
            wiener_index: 5000.0,
            balaban_j: 2.0,
            sum_electronegativity: 150.0,
            sum_polarizability: 80.0,
        }
    }

    /// A three-atom molecule (presumably water: element 8 bonded to two
    /// element-1 atoms) should get a low logP and pass the Lipinski check.
    #[test]
    fn test_predict_water() {
        let atoms = [8u8, 1, 1];
        let connectivity = [(0, 1, 1u8), (0, 2, 1)];
        let descriptors = compute_descriptors(&atoms, &connectivity, &[], &[]);

        let prediction = predict_properties(&descriptors);

        assert!(
            prediction.logp < 1.0,
            "Water logP should be low: {}",
            prediction.logp
        );
        assert!(prediction.lipinski.passes, "Water should pass Lipinski");
    }

    /// A descriptor set that breaks several thresholds must be reported as
    /// failing with at least two violations.
    #[test]
    fn test_lipinski_violations() {
        let prediction = predict_properties(&oversized_descriptors());
        let lipinski = &prediction.lipinski;

        assert!(lipinski.violations >= 2, "Should have multiple violations");
        assert!(!lipinski.passes, "Should fail Lipinski");
    }

    /// The drug-likeness score stays inside `[0, 1]` for a small organic
    /// molecule (two element-6 atoms, one element-8 atom and six hydrogens).
    #[test]
    fn test_druglikeness_range() {
        let atoms = [6u8, 6, 8, 1, 1, 1, 1, 1, 1];
        let connectivity = [
            // Heavy-atom backbone.
            (0, 1, 1u8),
            (1, 2, 1),
            // Hydrogens on atom 0.
            (0, 3, 1),
            (0, 4, 1),
            (0, 5, 1),
            // Hydrogens on atom 1.
            (1, 6, 1),
            (1, 7, 1),
            // Hydrogen on atom 2.
            (2, 8, 1),
        ];
        let descriptors = compute_descriptors(&atoms, &connectivity, &[], &[]);

        let prediction = predict_properties(&descriptors);

        assert!((0.0..=1.0).contains(&prediction.druglikeness));
    }
}