use anyhow::Result;
use rlx_ir::op::ReduceOp;
use rlx_ir::{DType, Graph, NodeId, Shape};
use crate::config::{DPModelConfig, EnerFittingConfig, SeAConfig};
use crate::descriptor::{build_se_a_descriptor, SeADescriptor, SeAExtraInputs};
use crate::fitting::{
build_dipole_fitting, build_invar_fitting, build_polar_fitting, DipoleFittingConfig,
FittingInputs, PolarFittingConfig,
};
/// Graph input node ids shared by the dipole, polar, DOS and property graphs
/// built in this module (all of them obtain it from `build_descriptor`).
pub struct DPInputs {
    /// Environment-matrix input, declared by `build_descriptor` as `F32`
    /// with shape `[nf, nloc, nnei, 4]`.
    pub env_mat: NodeId,
    /// Atom-type input, declared by `build_descriptor` as `I32` with shape
    /// `[nf, nloc]`.
    pub atype: NodeId,
}
/// Declares the model inputs on `g` and builds the se_a descriptor on top of them.
///
/// Two graph inputs are created, in this order:
/// * `"env_mat"`: `F32`, shape `[nf, nloc, nnei, 4]` with `nnei = descriptor_cfg.nnei()`;
/// * `"atype"`: `I32`, shape `[nf, nloc]`.
///
/// Returns the descriptor outputs together with the ids of those two input nodes.
///
/// # Errors
///
/// Propagates any error from `build_se_a_descriptor`.
fn build_descriptor(
    g: &mut Graph,
    descriptor_cfg: &SeAConfig,
    nf: usize,
    nloc: usize,
) -> Result<(SeADescriptor, DPInputs)> {
    let env_mat_shape = Shape::new(&[nf, nloc, descriptor_cfg.nnei(), 4], DType::F32);
    let atype_shape = Shape::new(&[nf, nloc], DType::I32);

    // Field initialisers run top to bottom, so `env_mat` is declared before `atype`.
    let inputs = DPInputs {
        env_mat: g.input("env_mat", env_mat_shape),
        atype: g.input("atype", atype_shape),
    };

    let descriptor = build_se_a_descriptor(
        g,
        descriptor_cfg,
        inputs.env_mat,
        inputs.atype,
        nf,
        nloc,
        SeAExtraInputs::default(),
    )?;
    Ok((descriptor, inputs))
}
/// Output handles of the graph produced by `build_dp_dipole_graph`.
pub struct DPDipoleGraph {
    /// Per-atom output of the dipole fitting network.
    pub dipole: NodeId,
    /// `dipole` summed over axis 1; declared with shape `[nf, 3]` (`F32`).
    pub global_dipole: NodeId,
    /// The `env_mat` / `atype` input nodes of the graph.
    pub inputs: DPInputs,
}
pub fn build_dp_dipole_graph(
descriptor_cfg: &SeAConfig,
fitting_cfg: &DipoleFittingConfig,
nf: usize,
nloc: usize,
) -> Result<(Graph, DPDipoleGraph)> {
let mut g = Graph::new("dp_dipole");
let (descriptor, inputs) = build_descriptor(&mut g, descriptor_cfg, nf, nloc)?;
let gr_4d = reshape_gr_for_fitting(&mut g, descriptor.gr, nf, nloc, descriptor_cfg);
let fit_inputs = FittingInputs {
descriptor: descriptor.descriptor,
atype_loc: inputs.atype,
fparam: None,
aparam: None,
case_embd: None,
exclude_mask: None,
gr: Some(gr_4d),
};
let out = build_dipole_fitting(&mut g, fitting_cfg, fit_inputs, nf, nloc)?;
let global = g.reduce(
out.dipole,
ReduceOp::Sum,
vec![1],
false,
Shape::new(&[nf, 3], DType::F32),
);
Ok((
g,
DPDipoleGraph {
dipole: out.dipole,
global_dipole: global,
inputs,
},
))
}
/// Output handles of the graph produced by `build_dp_polar_graph`.
pub struct DPPolarGraph {
    /// Per-atom output of the polarizability fitting network.
    pub polar: NodeId,
    /// `polar` summed over axis 1; declared with shape `[nf, 3, 3]` (`F32`).
    pub global_polar: NodeId,
    /// The `env_mat` / `atype` input nodes of the graph.
    pub inputs: DPInputs,
}
pub fn build_dp_polar_graph(
descriptor_cfg: &SeAConfig,
fitting_cfg: &PolarFittingConfig,
nf: usize,
nloc: usize,
) -> Result<(Graph, DPPolarGraph)> {
let mut g = Graph::new("dp_polar");
let (descriptor, inputs) = build_descriptor(&mut g, descriptor_cfg, nf, nloc)?;
let gr_4d = reshape_gr_for_fitting(&mut g, descriptor.gr, nf, nloc, descriptor_cfg);
let fit_inputs = FittingInputs {
descriptor: descriptor.descriptor,
atype_loc: inputs.atype,
fparam: None,
aparam: None,
case_embd: None,
exclude_mask: None,
gr: Some(gr_4d),
};
let out = build_polar_fitting(&mut g, fitting_cfg, fit_inputs, nf, nloc)?;
let global = g.reduce(
out.polar,
ReduceOp::Sum,
vec![1],
false,
Shape::new(&[nf, 3, 3], DType::F32),
);
Ok((
g,
DPPolarGraph {
polar: out.polar,
global_polar: global,
inputs,
},
))
}
/// Output handles of the graph produced by `build_dp_dos_graph`.
pub struct DPDosGraph {
    /// Per-atom output of the invariant fitting network.
    pub atom_dos: NodeId,
    /// `atom_dos` summed over axis 1; declared with shape
    /// `[nf, cfg.fitting_net.dim_out]` (`F32`).
    pub dos: NodeId,
    /// The `env_mat` / `atype` input nodes of the graph.
    pub inputs: DPInputs,
}
/// Builds a DOS model graph (named `"dp_dos"`): an se_a descriptor followed by an
/// invariant fitting network, using the parameter prefix `"fitting"` and the bias
/// parameter `"fitting.bias_atom_e"`.
///
/// The per-atom output is summed over axis 1 to produce `dos` with shape
/// `[nf, cfg.fitting_net.dim_out]`.
///
/// # Errors
///
/// Propagates errors from building the descriptor or the fitting network.
pub fn build_dp_dos_graph(
    cfg: &DPModelConfig,
    nf: usize,
    nloc: usize,
) -> Result<(Graph, DPDosGraph)> {
    let mut g = Graph::new("dp_dos");
    let (descriptor, inputs) = build_descriptor(&mut g, &cfg.descriptor, nf, nloc)?;

    // No frame/atomic parameters, case embedding, exclusion mask or `gr` for this head.
    let atom_dos = build_invar_fitting(
        &mut g,
        &cfg.fitting_net,
        FittingInputs {
            descriptor: descriptor.descriptor,
            atype_loc: inputs.atype,
            fparam: None,
            aparam: None,
            case_embd: None,
            exclude_mask: None,
            gr: None,
        },
        nf,
        nloc,
        "fitting",
        "fitting.bias_atom_e",
    )?
    .atom_output;

    let dos_shape = Shape::new(&[nf, cfg.fitting_net.dim_out], DType::F32);
    let dos = g.reduce(atom_dos, ReduceOp::Sum, vec![1], false, dos_shape);

    Ok((g, DPDosGraph { atom_dos, dos, inputs }))
}
/// Output handles of the graph produced by `build_dp_property_graph`.
pub struct DPPropertyGraph {
    /// Per-atom output of the invariant fitting network.
    pub atom_property: NodeId,
    /// `atom_property` summed over axis 1; declared with shape
    /// `[nf, cfg.fitting_net.dim_out]` (`F32`).
    pub property: NodeId,
    /// The `env_mat` / `atype` input nodes of the graph.
    pub inputs: DPInputs,
}
/// Builds a property model graph (named `"dp_property"`): an se_a descriptor
/// followed by an invariant fitting network under the parameter prefix `"fitting"`.
///
/// `bias_param_name` is forwarded to `build_invar_fitting` as the name of the
/// per-type bias parameter. The per-atom output is summed over axis 1 to produce
/// `property` with shape `[nf, cfg.fitting_net.dim_out]`.
///
/// # Errors
///
/// Propagates errors from building the descriptor or the fitting network.
pub fn build_dp_property_graph(
    cfg: &DPModelConfig,
    bias_param_name: &str,
    nf: usize,
    nloc: usize,
) -> Result<(Graph, DPPropertyGraph)> {
    let mut g = Graph::new("dp_property");
    let (descriptor, inputs) = build_descriptor(&mut g, &cfg.descriptor, nf, nloc)?;

    // No frame/atomic parameters, case embedding, exclusion mask or `gr` for this head.
    let atom_property = build_invar_fitting(
        &mut g,
        &cfg.fitting_net,
        FittingInputs {
            descriptor: descriptor.descriptor,
            atype_loc: inputs.atype,
            fparam: None,
            aparam: None,
            case_embd: None,
            exclude_mask: None,
            gr: None,
        },
        nf,
        nloc,
        "fitting",
        bias_param_name,
    )?
    .atom_output;

    let property_shape = Shape::new(&[nf, cfg.fitting_net.dim_out], DType::F32);
    let property = g.reduce(atom_property, ReduceOp::Sum, vec![1], false, property_shape);

    Ok((
        g,
        DPPropertyGraph {
            atom_property,
            property,
            inputs,
        },
    ))
}
/// Reshapes the descriptor's `gr` output to `[nf, nloc, ng, 3]` (`F32`),
/// with `ng = cfg.ng()`. This is the layout passed as `FittingInputs::gr` to the
/// dipole and polar fittings.
fn reshape_gr_for_fitting(
    g: &mut Graph,
    gr: NodeId,
    nf: usize,
    nloc: usize,
    cfg: &SeAConfig,
) -> NodeId {
    let dims = [nf, nloc, cfg.ng(), 3];
    // The reshape op takes signed target dims; `usize -> i64` is lossless for any
    // realistic tensor extent.
    let target: Vec<i64> = dims.iter().map(|&d| d as i64).collect();
    g.reshape(gr, target, Shape::new(&dims, DType::F32))
}
pub fn build_dp_energy_graph(
cfg: &DPModelConfig,
nf: usize,
nloc: usize,
) -> Result<(Graph, crate::model::DPEnergyGraph)> {
crate::model::build_dp_energy_graph(cfg, nf, nloc)
}
/// Creates an [`EnerFittingConfig`] sized to sit on top of `descriptor`.
///
/// * `ntypes` and `dim_descrpt` are taken from `descriptor`.
/// * `neuron` and `dim_out` are used as given.
/// * All other fields are fixed: `tanh` activation, `resnet_dt = true`,
///   `mixed_types = true`, and zero frame parameters, atomic parameters and
///   case-embedding dimensions.
pub fn ener_cfg_from(
    descriptor: &SeAConfig,
    neuron: Vec<usize>,
    dim_out: usize,
) -> EnerFittingConfig {
    let ntypes = descriptor.ntypes();
    let dim_descrpt = descriptor.dim_out();
    EnerFittingConfig {
        ntypes,
        dim_descrpt,
        neuron,
        dim_out,
        activation_function: "tanh".into(),
        resnet_dt: true,
        mixed_types: true,
        numb_fparam: 0,
        numb_aparam: 0,
        dim_case_embd: 0,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fitting::{DipoleFittingConfig, PolarFittingConfig};

    /// Small two-type se_a descriptor shared by every test in this module.
    fn descrpt() -> SeAConfig {
        SeAConfig {
            rcut: 6.0,
            rcut_smth: 0.5,
            sel: vec![46, 92],
            neuron: vec![16, 32, 64],
            axis_neuron: 4,
            resnet_dt: false,
            type_one_side: true,
            activation_function: "tanh".into(),
            type_map: None,
        }
    }

    #[test]
    fn dipole_graph_builds() {
        let descriptor = descrpt();
        let fitting = DipoleFittingConfig {
            base: ener_cfg_from(&descriptor, vec![16], 1),
            embedding_width: descriptor.ng(),
        };
        let (graph, _) = build_dp_dipole_graph(&descriptor, &fitting, 1, 4).expect("build");
        assert!(graph.len() > 20);
    }

    #[test]
    fn polar_graph_builds() {
        let descriptor = descrpt();
        let fitting = PolarFittingConfig {
            base: ener_cfg_from(&descriptor, vec![16], 1),
            embedding_width: descriptor.ng(),
            fit_diag: false,
            shift_diag: true,
        };
        let (graph, _) = build_dp_polar_graph(&descriptor, &fitting, 1, 4).expect("build");
        assert!(graph.len() > 25);
    }

    #[test]
    fn dos_graph_builds() {
        let descriptor = descrpt();
        // Same config as spelling out every field: neuron [16, 16], 32 DOS channels.
        let fitting_net = ener_cfg_from(&descriptor, vec![16, 16], 32);
        let cfg = DPModelConfig {
            descriptor,
            fitting_net,
            type_map: None,
        };
        let (graph, out) = build_dp_dos_graph(&cfg, 1, 4).expect("build");
        assert!(graph.len() > 20);
        let dos_shape = graph.shape(out.dos);
        assert_eq!(dos_shape.dim(1), rlx_ir::Dim::Static(32));
    }
}