use anyhow::Result;
use rlx_ir::op::ReduceOp;
use rlx_ir::{DType, Graph, NodeId, Shape};
use crate::config::DPModelConfig;
use crate::descriptor::{build_se_a_descriptor, SeADescriptor};
use crate::fitting::{build_ener_fitting, EnerFitting};
/// Node handles produced by [`build_dp_energy_graph`].
///
/// The graph itself is returned alongside this struct; these ids index into it.
pub struct DPEnergyGraph {
    /// Fitting output summed over the local-atom axis; declared with shape
    /// `[nf, fitting_net.dim_out]`, `F32`.
    pub energy: NodeId,
    /// Raw output node of `build_ener_fitting` (the `atom_energy` field of `EnerFitting`).
    pub atom_energy: NodeId,
    /// Descriptor sub-graph handles returned by `build_se_a_descriptor`.
    pub descriptor: SeADescriptor,
    /// Graph input nodes that callers must feed.
    pub inputs: DPEnergyInputs,
}
/// Input nodes of the graph built by [`build_dp_energy_graph`].
#[derive(Debug, Clone, Copy)]
pub struct DPEnergyInputs {
    /// Graph input named `"env_mat"`; declared with shape `[nf, nloc, nnei, 4]`, `F32`,
    /// where `nnei = cfg.descriptor.nnei()`.
    pub env_mat: NodeId,
    /// Graph input named `"atype"`; declared with shape `[nf, nloc]`, `I32`.
    pub atype: NodeId,
}
pub fn build_dp_energy_graph(
cfg: &DPModelConfig,
nf: usize,
nloc: usize,
) -> Result<(Graph, DPEnergyGraph)> {
let mut g = Graph::new("dp_energy");
let nnei = cfg.descriptor.nnei();
let env_mat = g.input("env_mat", Shape::new(&[nf, nloc, nnei, 4], DType::F32));
let atype = g.input("atype", Shape::new(&[nf, nloc], DType::I32));
let descriptor = build_se_a_descriptor(
&mut g,
&cfg.descriptor,
env_mat,
atype,
nf,
nloc,
crate::descriptor::SeAExtraInputs::default(),
)?;
assert_eq!(
descriptor.dim_out, cfg.fitting_net.dim_descrpt,
"descriptor.dim_out ({}) must match fitting_net.dim_descrpt ({})",
descriptor.dim_out, cfg.fitting_net.dim_descrpt
);
let EnerFitting { atom_energy } = build_ener_fitting(
&mut g,
&cfg.fitting_net,
nf,
nloc,
descriptor.descriptor,
atype,
None,
None,
None,
)?;
let dim_out = cfg.fitting_net.dim_out;
let atom_e_t = {
use rlx_ir::infer::GraphExt;
g.transpose_(atom_energy, vec![0, 2, 1]) };
let energy = g.reduce(
atom_e_t,
ReduceOp::Sum,
vec![2],
false,
Shape::new(&[nf, dim_out], DType::F32),
);
Ok((
g,
DPEnergyGraph {
energy,
atom_energy,
descriptor,
inputs: DPEnergyInputs { env_mat, atype },
},
))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::DPModelConfig;

    /// Builds the energy graph for `cfg` and checks that the descriptor width matches
    /// the config. Returns the graph so each test can check its own size bound.
    fn build_checked(cfg: &DPModelConfig, nf: usize, nloc: usize) -> Graph {
        let (graph, built) = build_dp_energy_graph(cfg, nf, nloc).expect("build");
        assert_eq!(built.descriptor.dim_out, cfg.descriptor.dim_out());
        graph
    }

    #[test]
    fn dp_energy_graph_builds_single_type() {
        let graph = build_checked(&DPModelConfig::new(6.0, 0.5, vec![46]), 1, 8);
        assert!(graph.len() > 10);
    }

    #[test]
    fn dp_energy_graph_builds_multi_type() {
        let graph = build_checked(&DPModelConfig::new(6.0, 0.5, vec![46, 92]), 2, 192);
        assert!(graph.len() > 20);
    }

    #[test]
    fn dp_energy_graph_type_two_side() {
        // Same two-species config, but with per-pair (two-sided) embedding enabled.
        let mut two_side = DPModelConfig::new(6.0, 0.5, vec![46, 92]);
        two_side.descriptor.type_one_side = false;
        let graph = build_checked(&two_side, 1, 16);
        assert!(graph.len() > 40);
    }
}