#![allow(unused)]
use std::path::{Path, PathBuf};
use bio_files::{BondType, Mol2};
use candle_core::{DType, Device, Tensor};
use candle_nn::{AdamW, Optimizer, ParamsAdamW, VarBuilder};
use rand::seq::SliceRandom;
use crate::partial_charge_inference::{
AtomVocab, MolGNN,
files::{GEOSTD_PATH, MODEL_PATH, VOCAB_PATH, find_mol2_paths},
save,
};
/// Maximum number of training epochs.
const N_EPOCHS: u8 = 50;
/// Stop training after this many consecutive epochs without validation improvement.
const EARLY_STOPPING_PATIENCE: u8 = 7;
/// Hidden feature dimension passed to `MolGNN::new`.
const HIDDEN_DIM: usize = 128;
/// AdamW learning rate.
const LEARNING_RATE: f64 = 1e-3;
/// A single molecule converted to tensors for one training step.
///
/// Shapes below follow the construction in `GeoStdMol2Dataset::get`.
pub struct Batch {
    /// Element-vocabulary id per atom; `(num_atoms,)`, i64. Out-of-vocab
    /// elements get the id one past the end of the vocab.
    pub elem_ids: Tensor,
    /// Cartesian atom coordinates; `(num_atoms, 3)`, f32.
    pub coords: Tensor,
    /// Directed bond pairs (each bond appears in both directions);
    /// `(num_directed_edges, 2)`, i64.
    pub edge_index: Tensor,
    /// Force-field atom-type vocabulary id per atom; `(num_atoms,)`, i64.
    pub type_ids: Tensor,
    /// Per-atom f32 mask; `(num_atoms,)`. Intended to flag presence of a
    /// force-field type (NOTE(review): currently always 1.0 as built).
    pub has_type: Tensor,
    /// Target partial charge per atom (0.0 when unlabelled); `(num_atoms,)`, f32.
    pub charges: Tensor,
    /// Number of atoms in the molecule.
    pub num_atoms: usize,
    /// Bond-type class id per directed edge; `(num_directed_edges,)`, i64.
    pub edge_types: Tensor,
}
/// Lazily-loading dataset over GeoStd `.mol2` files: each `get` call parses
/// the file from disk and tensorizes it (nothing is cached in memory).
struct GeoStdMol2Dataset {
    /// Paths of the `.mol2` files that make up the dataset.
    mol2_paths: Vec<PathBuf>,
    /// Element / atom-type vocabularies plus charge-normalization stats.
    vocabs: AtomVocab,
}
impl GeoStdMol2Dataset {
    /// Creates a dataset over the given `.mol2` paths with pre-built vocabularies.
    pub fn new(mol2_paths: &[PathBuf], vocabs: AtomVocab) -> Self {
        Self {
            mol2_paths: mol2_paths.to_vec(),
            vocabs,
        }
    }

    /// Number of molecules (files) in the dataset.
    pub fn len(&self) -> usize {
        self.mol2_paths.len()
    }

    /// Whether the dataset contains no molecules.
    pub fn is_empty(&self) -> bool {
        self.mol2_paths.is_empty()
    }

    /// Loads molecule `i_mol` from disk and tensorizes it as a [`Batch`].
    ///
    /// Out-of-vocabulary elements / atom types map to the id one past the end
    /// of the corresponding vocabulary; atoms without a partial-charge label
    /// get a target of 0.0. Each bond is emitted twice (i→j and j→i) so the
    /// graph is effectively undirected.
    pub fn get(&self, i_mol: usize, device: &Device) -> candle_core::Result<Batch> {
        let mol = Mol2::load(&self.mol2_paths[i_mol])?;
        let atoms = &mol.atoms;
        let bonds = &mol.bonds;
        let n = atoms.len();

        let mut elem_ids = Vec::with_capacity(n);
        let mut has_type = Vec::with_capacity(n);
        let mut type_ids = Vec::with_capacity(n);
        let mut charges = Vec::with_capacity(n);
        let mut coords = Vec::with_capacity(n * 3);

        // Reserved OOV ids: one past the last known vocabulary entry. The
        // model's embedding tables must therefore have vocab_len + 1 rows.
        let oov_elem_id = self.vocabs.el.len();
        let oov_type_id = self.vocabs.atom_type.len();

        for atom in atoms.iter() {
            let el_id = self
                .vocabs
                .el
                .get(&atom.element.to_letter())
                .copied()
                .unwrap_or(oov_elem_id);
            elem_ids.push(el_id as i64);

            coords.push(atom.posit.x as f32);
            coords.push(atom.posit.y as f32);
            coords.push(atom.posit.z as f32);

            let t_id = atom
                .force_field_type
                .as_ref()
                .and_then(|ff| self.vocabs.atom_type.get(ff))
                .copied()
                .unwrap_or(oov_type_id);
            type_ids.push(t_id as i64);
            // Fix: this mask was previously pushed as 1.0 unconditionally,
            // which made it carry no information. It now flags whether the
            // atom actually carries a force-field type.
            has_type.push(if atom.force_field_type.is_some() {
                1.0f32
            } else {
                0.0
            });

            // Unlabelled atoms get a 0.0 target charge.
            charges.push(atom.partial_charge.unwrap_or(0.0));
        }

        // Two directed edges per bond; preallocate accordingly.
        let n_directed = bonds.len() * 2;
        let mut edge_index_vec: Vec<i64> = Vec::with_capacity(n_directed * 2);
        let mut edge_types_vec: Vec<i64> = Vec::with_capacity(n_directed);
        for bond in bonds.iter() {
            // Mol2 serial numbers are 1-based; convert to 0-based indices.
            let i = (bond.atom_0_sn - 1) as i64;
            let j = (bond.atom_1_sn - 1) as i64;
            // Bond-type class: 0=single, 1=double, 2=triple, 3=aromatic, 4=other.
            let bt_id = match bond.bond_type {
                BondType::Single => 0,
                BondType::Double => 1,
                BondType::Triple => 2,
                BondType::Aromatic => 3,
                _ => 4,
            };
            edge_index_vec.extend_from_slice(&[i, j]);
            edge_types_vec.push(bt_id);
            edge_index_vec.extend_from_slice(&[j, i]);
            edge_types_vec.push(bt_id);
        }

        let elem_ids = Tensor::from_slice(&elem_ids, (n,), device)?;
        let coords = Tensor::from_slice(&coords, (n, 3), device)?;
        let type_ids = Tensor::from_slice(&type_ids, (n,), device)?;
        let has_type = Tensor::from_slice(&has_type, (n,), device)?;
        let charges = Tensor::from_slice(&charges, (n,), device)?;

        // Tensor::from_slice rejects empty shapes inconsistently across
        // backends, so build explicit empty tensors for bond-free molecules.
        let edge_index = if edge_index_vec.is_empty() {
            Tensor::zeros((0, 2), DType::I64, device)?
        } else {
            let m = edge_index_vec.len() / 2;
            Tensor::from_slice(&edge_index_vec, (m, 2), device)?
        };
        let edge_types = if edge_types_vec.is_empty() {
            Tensor::zeros((0,), DType::I64, device)?
        } else {
            Tensor::from_slice(&edge_types_vec, (edge_types_vec.len(),), device)?
        };

        Ok(Batch {
            elem_ids,
            coords,
            edge_index,
            type_ids,
            has_type,
            charges,
            num_atoms: n,
            edge_types,
        })
    }
}
pub(crate) fn run_training() -> candle_core::Result<()> {
let device = Device::Cpu;
let mut rng = rand::rng();
println!("Training on GeoStd data with device: {device:?}");
let paths_mol2 = find_mol2_paths(Path::new(GEOSTD_PATH))?;
let mut sum = 0f64;
let mut sum_sq = 0f64;
let mut count = 0usize;
for path in &paths_mol2 {
let mol = Mol2::load(path)?;
for atom in mol.atoms.iter() {
if let Some(pc) = atom.partial_charge {
let q = pc as f64;
sum += q;
sum_sq += q * q;
count += 1;
}
}
}
let mean = (sum / count as f64) as f32;
let var = (sum_sq / count as f64) - (sum / count as f64).powi(2);
let std = var.max(1e-8).sqrt() as f32;
println!("Charge normalization: mean={mean}, std={std}");
let mut vocabs = AtomVocab::new(&paths_mol2)?;
vocabs.charge_mean = mean;
vocabs.charge_std = std;
let n_elems = vocabs.el.len();
let n_atom_types = vocabs.atom_type.len();
save(Path::new(VOCAB_PATH), &vocabs)?;
println!("Vocabs built and saved to {VOCAB_PATH}");
let dataset = GeoStdMol2Dataset::new(&paths_mol2, vocabs);
let mut varmap = candle_nn::VarMap::new();
let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);
let model = MolGNN::new(vb, n_elems, n_atom_types, HIDDEN_DIM)?;
let mut opt = AdamW::new(
varmap.all_vars(),
ParamsAdamW {
lr: LEARNING_RATE,
..Default::default()
},
)?;
let mut all_idxs: Vec<usize> = (0..dataset.len()).collect();
all_idxs.shuffle(&mut rng);
let split = (dataset.len() as f32 * 0.8) as usize;
let train_idxs = all_idxs[..split].to_vec();
let val_idxs = all_idxs[split..].to_vec();
let mut best_val_loss = f32::INFINITY;
let mut epochs_without_improve: u8 = 0;
for epoch in 0..N_EPOCHS {
println!("Running epoch {epoch} / {N_EPOCHS}...");
let mut train_order = train_idxs.clone();
train_order.shuffle(&mut rng);
let mut running_loss = 0.;
for i in train_order.iter() {
let batch = dataset.get(*i, &device)?;
let charges_pred = model.forward(
&batch.elem_ids,
&batch.type_ids,
&batch.coords,
&batch.edge_index,
&batch.edge_types,
)?;
let mean = dataset.vocabs.charge_mean;
let std = dataset.vocabs.charge_std;
let mean_t = Tensor::from_slice(&[mean], (1,), &device)?;
let std_t = Tensor::from_slice(&[std], (1,), &device)?;
let charges_norm = batch
.charges
.broadcast_sub(&mean_t)?
.broadcast_div(&std_t)?;
let diff = (charges_pred.clone() - &charges_norm)?;
let charge_loss = diff.sqr()?.mean_all()?;
let charges_pred_raw = charges_pred.broadcast_mul(&std_t)?.broadcast_add(&mean_t)?;
let sum_pred = charges_pred_raw.sum_all()?;
let sum_true = batch.charges.sum_all()?;
let sum_loss = (sum_pred - sum_true)?.sqr()?;
let loss = (&charge_loss + sum_loss * 0.01)?;
opt.backward_step(&loss)?;
running_loss += loss.to_scalar::<f32>()?;
}
let train_avg = running_loss / train_idxs.len() as f32;
let mut val_loss_sum = 0f32;
for i in val_idxs.iter() {
let batch = dataset.get(*i, &device)?;
let charges_pred = model.forward(
&batch.elem_ids,
&batch.type_ids,
&batch.coords,
&batch.edge_index,
&batch.edge_types,
)?;
let mean = dataset.vocabs.charge_mean;
let std = dataset.vocabs.charge_std;
let mean_t = Tensor::from_slice(&[mean], (1,), &device)?;
let std_t = Tensor::from_slice(&[std], (1,), &device)?;
let charges_norm = batch
.charges
.broadcast_sub(&mean_t)?
.broadcast_div(&std_t)?;
let diff = (charges_pred - &charges_norm)?;
let charge_loss = diff.sqr()?.mean_all()?;
let loss = charge_loss;
val_loss_sum += loss.to_scalar::<f32>()?;
}
let val_avg = val_loss_sum / val_idxs.len() as f32;
println!("Epoch {epoch} done. Train avg loss: {train_avg}, Val avg loss: {val_avg}");
if val_avg < best_val_loss {
best_val_loss = val_avg;
epochs_without_improve = 0;
varmap.save(MODEL_PATH)?; println!("New best val loss. Saved model to {MODEL_PATH}");
} else {
epochs_without_improve += 1;
if epochs_without_improve >= EARLY_STOPPING_PATIENCE {
println!(
"Early stopping at epoch {epoch} (no improvement for {EARLY_STOPPING_PATIENCE} epochs)"
);
break;
}
}
}
varmap.save(MODEL_PATH)?;
println!("Saved model to {MODEL_PATH}");
Ok(())
}