use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use super::inference::{InferenceNet, MlffResult};
use super::symmetry_functions::{compute_aevs, SymmetryFunctionParams};
/// Configuration for the machine-learned force field: the parameters used to
/// build atomic environment vectors (AEVs) plus one inference network per
/// element.
// Field order is part of the serialized form, so do not reorder fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MlffConfig {
/// Parameters handed to `compute_aevs` when building each atom's AEV.
pub aev_params: SymmetryFunctionParams,
/// Per-element networks keyed by atomic number Z; every element in a
/// molecule must have an entry or `compute_mlff` returns an error.
pub element_nets: HashMap<u8, InferenceNet>,
}
/// Evaluates the machine-learned force field for one molecular geometry.
///
/// The function proceeds in three steps:
///
/// 1. Each atom's atomic environment vector (AEV) is computed with
///    `config.aev_params`. It is fed to the network registered for that
///    atom's element in `config.element_nets`.
/// 2. The first network output is taken as that atom's energy contribution.
///    The total energy is the sum of the contributions.
/// 3. Forces are obtained separately by central finite differences of the
///    total energy (see `compute_mlff_forces`).
///
/// # Errors
///
/// Returns `Err` if:
/// - `elements` is empty;
/// - `elements` and `positions` differ in length;
/// - any element has no network in `config.element_nets`;
/// - `compute_aevs` does not return exactly one AEV per atom;
/// - a network produces an empty output vector.
pub fn compute_mlff(
    elements: &[u8],
    positions: &[[f64; 3]],
    config: &MlffConfig,
) -> Result<MlffResult, String> {
    let n = elements.len();
    if n == 0 {
        return Err("Empty molecule".into());
    }
    if positions.len() != n {
        return Err(format!(
            "elements ({}) and positions ({}) length mismatch",
            n,
            positions.len()
        ));
    }

    // Resolve every atom's network up front (one hash lookup per atom). The
    // first missing element is reported before any AEV work is done.
    let nets = elements
        .iter()
        .map(|z| {
            config
                .element_nets
                .get(z)
                .ok_or_else(|| format!("No neural network parameters for element Z={}", z))
        })
        .collect::<Result<Vec<_>, String>>()?;

    let aevs = compute_aevs(elements, positions, &config.aev_params);
    // Downstream code pairs AEVs with atoms by position. A count mismatch
    // would otherwise panic on indexing (or be silently truncated by `zip`).
    if aevs.len() != n {
        return Err(format!(
            "compute_aevs returned {} AEVs for {} atoms",
            aevs.len(),
            n
        ));
    }

    let mut atomic_energies = Vec::with_capacity(n);
    let mut total_energy = 0.0;
    for ((&z, net), aev) in elements.iter().zip(&nets).zip(&aevs) {
        let output = net.forward(&aev.to_vec());
        // A network with no outputs is a configuration error. Report it
        // instead of panicking on `output[0]`.
        let e_atom = *output
            .first()
            .ok_or_else(|| format!("Neural network for element Z={} produced no output", z))?;
        atomic_energies.push(e_atom);
        total_energy += e_atom;
    }

    let forces = compute_mlff_forces(elements, positions, config, &aevs)?;
    Ok(MlffResult {
        energy: total_energy,
        atomic_energies,
        forces,
    })
}
/// Approximates per-atom forces as the negative gradient of the total energy,
/// using central finite differences.
///
/// Each Cartesian coordinate is displaced by `+STEP` and then `-STEP` in a
/// single working copy of the geometry. The total energy is re-evaluated with
/// `compute_energy_only` each time, and the coordinate is then restored. This
/// costs `6 * n` full energy evaluations. `_aevs` is accepted but not used.
///
/// Currently always returns `Ok`.
fn compute_mlff_forces(
    elements: &[u8],
    positions: &[[f64; 3]],
    config: &MlffConfig,
    _aevs: &[super::symmetry_functions::Aev],
) -> Result<Vec<[f64; 3]>, String> {
    const STEP: f64 = 1e-4;
    let mut displaced = positions.to_vec();
    let mut forces = vec![[0.0f64; 3]; elements.len()];
    for (atom, force) in forces.iter_mut().enumerate() {
        for axis in 0..3 {
            let x0 = positions[atom][axis];

            displaced[atom][axis] = x0 + STEP;
            let e_plus = compute_energy_only(elements, &displaced, config);

            displaced[atom][axis] = x0 - STEP;
            let e_minus = compute_energy_only(elements, &displaced, config);

            // Put the coordinate back before moving on to the next one.
            displaced[atom][axis] = x0;

            force[axis] = -(e_plus - e_minus) / (2.0 * STEP);
        }
    }
    Ok(forces)
}
/// Returns the total predicted energy for `positions`.
///
/// The energy is the sum, in AEV order, of the first output of each atom's
/// element network applied to that atom's AEV.
///
/// # Panics
///
/// Panics in any of these cases:
/// - an element has no entry in `config.element_nets` (`compute_mlff`
///   validates this before forces are computed);
/// - `compute_aevs` yields more AEVs than there are elements;
/// - a network returns an empty output.
fn compute_energy_only(elements: &[u8], positions: &[[f64; 3]], config: &MlffConfig) -> f64 {
    let aevs = compute_aevs(elements, positions, &config.aev_params);
    aevs.iter().enumerate().fold(0.0, |energy, (idx, aev)| {
        let output = config.element_nets[&elements[idx]].forward(&aev.to_vec());
        energy + output[0]
    })
}
/// Energy and flattened gradient returned by
/// `compute_mlff_energy_and_gradient` for use by flat-array drivers.
// Field order is part of the serialized form, so do not reorder fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MlffDynamicsResult {
/// Total predicted energy (the `energy` of the underlying `compute_mlff` result).
pub energy: f64,
/// Gradient laid out as `[x0, y0, z0, x1, y1, z1, ...]`; each entry is the
/// negated force component for that coordinate.
pub gradient_flat: Vec<f64>,
}
/// Flat-array front end to `compute_mlff`.
///
/// `positions_flat` holds `x0, y0, z0, x1, y1, z1, ...` for the atoms in
/// `elements`. The returned `gradient_flat` uses the same layout. Each entry
/// is the negated force on that coordinate, i.e. dE/dx = -F.
///
/// # Errors
///
/// Returns `Err` if `positions_flat.len() != 3 * elements.len()`. Also
/// propagates any error from `compute_mlff`, e.g. for an empty molecule or a
/// missing element network.
pub fn compute_mlff_energy_and_gradient(
    elements: &[u8],
    positions_flat: &[f64],
    config: &MlffConfig,
) -> Result<MlffDynamicsResult, String> {
    let atom_count = elements.len();
    let coord_count = positions_flat.len();
    if coord_count != atom_count * 3 {
        return Err(format!(
            "positions_flat length {} != 3 × {}",
            coord_count, atom_count
        ));
    }

    // The length check above guarantees every chunk is a full xyz triple.
    let positions: Vec<[f64; 3]> = positions_flat
        .chunks_exact(3)
        .map(|xyz| [xyz[0], xyz[1], xyz[2]])
        .collect();

    let result = compute_mlff(elements, &positions, config)?;

    let mut gradient_flat = vec![0.0; atom_count * 3];
    for (grad, force) in gradient_flat.chunks_exact_mut(3).zip(&result.forces) {
        for (g, &f) in grad.iter_mut().zip(force.iter()) {
            *g = -f;
        }
    }

    Ok(MlffDynamicsResult {
        energy: result.energy,
        gradient_flat,
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ml::inference::{Activation, DenseLayer, InferenceNet};
    use crate::ml::symmetry_functions::SymmetryFunctionParams;
    use std::collections::HashMap;

    /// Builds a config with a single linear network for hydrogen (Z=1).
    fn dummy_config() -> MlffConfig {
        let params = SymmetryFunctionParams::default();
        // NOTE(review): assumes the AEV length is radial + angular terms with no
        // per-species channels. Confirm against `compute_aevs`.
        let aev_len = params.radial_etas.len() * params.radial_shifts.len()
            + params.angular_etas.len() * params.angular_zetas.len() * params.angular_shifts.len();
        let layer = DenseLayer {
            weights: vec![0.01; aev_len],
            bias: vec![0.0],
            in_features: aev_len,
            out_features: 1,
            activation: Activation::Linear,
        };
        let net = InferenceNet::new(vec![layer]);
        let mut element_nets = HashMap::new();
        element_nets.insert(1u8, net);
        MlffConfig {
            aev_params: params,
            element_nets,
        }
    }

    /// H2 at a 0.74 separation along x, shared by several tests.
    fn h2_positions() -> [[f64; 3]; 2] {
        [[0.0, 0.0, 0.0], [0.74, 0.0, 0.0]]
    }

    #[test]
    fn mlff_h2_returns_result() {
        let config = dummy_config();
        let result = compute_mlff(&[1, 1], &h2_positions(), &config);
        assert!(result.is_ok(), "MLFF should succeed for H2");
        let r = result.unwrap();
        assert_eq!(r.forces.len(), 2);
        assert!(r.energy.is_finite());
    }

    #[test]
    fn mlff_energy_is_sum_of_atomic_energies() {
        let config = dummy_config();
        let r = compute_mlff(&[1, 1], &h2_positions(), &config).unwrap();
        assert_eq!(r.atomic_energies.len(), 2);
        let sum: f64 = r.atomic_energies.iter().sum();
        assert!((r.energy - sum).abs() < 1e-12);
    }

    #[test]
    fn mlff_h2_forces_are_equal_and_opposite() {
        // For an isolated homonuclear pair the two forces must cancel
        // (Newton's third law / translational invariance).
        let config = dummy_config();
        let r = compute_mlff(&[1, 1], &h2_positions(), &config).unwrap();
        for axis in 0..3 {
            let net = r.forces[0][axis] + r.forces[1][axis];
            assert!(net.abs() < 1e-6, "net force on axis {} is {}", axis, net);
        }
    }

    #[test]
    fn mlff_gradient_has_correct_length() {
        let config = dummy_config();
        let positions = [0.0, 0.0, 0.0, 0.74, 0.0, 0.0];
        let result = compute_mlff_energy_and_gradient(&[1, 1], &positions, &config);
        assert!(result.is_ok());
        let r = result.unwrap();
        assert_eq!(r.gradient_flat.len(), 6);
    }

    #[test]
    fn mlff_gradient_is_negated_force() {
        let config = dummy_config();
        let elements = [1u8, 1];
        let positions = h2_positions();
        let flat: Vec<f64> = positions.iter().flatten().copied().collect();
        let direct = compute_mlff(&elements, &positions, &config).unwrap();
        let dynamics = compute_mlff_energy_and_gradient(&elements, &flat, &config).unwrap();
        assert!((dynamics.energy - direct.energy).abs() < 1e-9);
        assert_eq!(dynamics.gradient_flat.len(), direct.forces.len() * 3);
        for (g, f) in dynamics.gradient_flat.iter().zip(direct.forces.iter().flatten()) {
            assert!((g + f).abs() < 1e-9, "gradient {} is not -force {}", g, f);
        }
    }

    #[test]
    fn mlff_missing_element_net_returns_error() {
        let config = dummy_config();
        let result = compute_mlff(&[6, 6], &[[0.0, 0.0, 0.0], [1.5, 0.0, 0.0]], &config);
        assert!(result.is_err(), "MLFF should fail for missing element nets");
    }

    #[test]
    fn mlff_empty_molecule_returns_error() {
        let config = dummy_config();
        assert!(compute_mlff(&[], &[], &config).is_err());
    }

    #[test]
    fn mlff_length_mismatch_returns_error() {
        let config = dummy_config();
        let result = compute_mlff(&[1, 1], &[[0.0, 0.0, 0.0]], &config);
        assert!(result.is_err());
    }

    #[test]
    fn mlff_gradient_rejects_bad_flat_length() {
        let config = dummy_config();
        let result = compute_mlff_energy_and_gradient(&[1, 1], &[0.0, 0.0, 0.0, 0.74], &config);
        assert!(result.is_err());
    }
}