use anyhow::Result;
use rlx_ir::infer::GraphExt;
use rlx_ir::{DType, Graph, NodeId, Shape};
use serde::{Deserialize, Serialize};
use crate::descriptor_dpa1::{
build_dpa1_descriptor, DPA1Config, DPA1Inputs,
};
use crate::descriptor_repformers::{
build_repformers, RepformersConfig, RepformersInputs,
};
use crate::descriptor_se_t_tebd::TebdInputMode;
use crate::nn::{embedding_mlp, ActivationKind, MlpSpec};
/// Serde-deserializable configuration consumed by `build_dpa2_descriptor`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct DPA2Config {
    /// Arguments for the repinit stage; converted into a DPA-1 config by
    /// `repinit_to_dpa1`.
    pub repinit: RepinitArgs,
    /// Configuration handed to `build_repformers`. Its `g1_dim` is the width the
    /// repinit features are projected to before entering the repformers.
    pub repformers: RepformersConfig,
    /// Number of atom types; forwarded to the repinit and three-body sub-descriptors.
    pub ntypes: usize,
    /// Type-embedding width (defaults to 8). Used as the input width of
    /// `tebd_transform` and added to the output width when `concat_output_tebd` is set.
    #[serde(default = "default_tebd_dim")]
    pub tebd_dim: usize,
    /// When true (the default), the centre atom's type-embedding row is
    /// concatenated onto the repformer output on the last axis.
    #[serde(default = "default_true")]
    pub concat_output_tebd: bool,
    /// When true, a `tebd_transform` projection of the centre atom's
    /// type-embedding row is added to the projected repinit features.
    #[serde(default)]
    pub add_tebd_to_repinit_out: bool,
    /// Arguments for the three-body branch; required when
    /// `repinit.use_three_body` is set, otherwise unused.
    #[serde(default)]
    pub repinit_three_body: Option<RepinitArgs>,
    /// Activation name, parsed with `ActivationKind::parse`; defaults to "tanh".
    #[serde(default = "default_activation")]
    pub activation_function: String,
}
/// Arguments shared by the main repinit stage (`DPA2Config::repinit`) and the
/// optional three-body branch (`DPA2Config::repinit_three_body`).
///
/// When used for the three-body branch, only `rcut`, `rcut_smth`, `nsel`,
/// `neuron`, `tebd_input_mode` and `resnet_dt` are read by `build_dpa2_descriptor`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct RepinitArgs {
    /// Forwarded unchanged as the sub-descriptor's `rcut`.
    pub rcut: f64,
    /// Forwarded unchanged as the sub-descriptor's `rcut_smth`.
    pub rcut_smth: f64,
    /// Forwarded as the single-element `sel` list of the sub-descriptor.
    pub nsel: usize,
    /// Embedding-net layer widths; defaults to `[25, 50, 100]`.
    #[serde(default = "default_neuron")]
    pub neuron: Vec<usize>,
    /// Defaults to 16; only read for the main repinit stage.
    #[serde(default = "default_axis_neuron")]
    pub axis_neuron: usize,
    /// How type embeddings are fed to the sub-descriptor; uses `TebdInputMode::default()`.
    #[serde(default)]
    pub tebd_input_mode: TebdInputMode,
    /// Only read for the main repinit stage.
    #[serde(default)]
    pub type_one_side: bool,
    #[serde(default)]
    pub resnet_dt: bool,
    /// On `DPA2Config::repinit`, enables the three-body branch. Not read when
    /// these args are the `repinit_three_body` entry.
    #[serde(default)]
    pub use_three_body: bool,
}
/// Serde default for `RepinitArgs::neuron`.
fn default_neuron() -> Vec<usize> {
    Vec::from([25, 50, 100])
}

/// Serde default for `RepinitArgs::axis_neuron`.
fn default_axis_neuron() -> usize {
    16_usize
}

/// Serde default for `DPA2Config::tebd_dim`.
fn default_tebd_dim() -> usize {
    8_usize
}

/// Serde default for boolean fields that are enabled unless configured otherwise.
fn default_true() -> bool {
    true
}

/// Serde default for `DPA2Config::activation_function`.
fn default_activation() -> String {
    String::from("tanh")
}
/// Graph nodes supplied by the caller to `build_dpa2_descriptor`.
pub struct DPA2Inputs {
    /// Raw environment matrix for the repinit stage (`DPA1Inputs::env_mat_raw`).
    pub env_mat_raw_repinit: NodeId,
    /// Raw environment matrix for the repformers (`RepformersInputs::env_mat_raw`).
    pub env_mat_raw_repformer: NodeId,
    /// Per-local-atom type index; used to gather rows of `type_embedding` along axis 0.
    pub atype_loc: NodeId,
    /// Neighbour types for the repinit stage.
    pub nei_atype: NodeId,
    /// NOTE(review): not read by `build_dpa2_descriptor` — confirm whether it is still needed.
    pub nlist_repinit: NodeId,
    /// Neighbour mask for the repinit stage.
    pub nlist_mask_repinit: NodeId,
    /// Switch/smoothing weights for the repinit stage.
    pub sw_repinit: NodeId,
    /// Neighbour list for the repformers.
    pub nlist_repformer: NodeId,
    /// Neighbour mask for the repformers.
    pub nlist_mask_repformer: NodeId,
    /// Switch/smoothing weights for the repformers.
    pub sw_repformer: NodeId,
    /// Type-embedding table, indexed by atom type on axis 0.
    pub type_embedding: NodeId,
    /// NOTE(review): not read by `build_dpa2_descriptor` — confirm whether it is still needed.
    pub g1_ext_tebd: NodeId,
    /// Index map gathered along axis 1 to extend local features to all atoms;
    /// only used when `nloc != nall`.
    pub mapping: NodeId,
    /// Required when `cfg.repinit.use_three_body` is set.
    pub env_mat_raw_three_body: Option<NodeId>,
    /// Required when `cfg.repinit.use_three_body` is set.
    pub nei_atype_three_body: Option<NodeId>,
    /// Required when `cfg.repinit.use_three_body` is set.
    pub sw_three_body: Option<NodeId>,
}
/// Output nodes of `build_dpa2_descriptor`.
pub struct DPA2Descriptor {
    /// Final per-atom descriptor. When `concat_output_tebd` is set its shape is
    /// built as `[nf, nloc, dim_out]`; otherwise it is the repformers' `g1` as-is.
    pub descriptor: NodeId,
    /// The repformers' `rot_mat` output.
    pub gr: NodeId,
    /// Repinit output before any three-body concat or width projection.
    pub debug_repinit: Option<NodeId>,
    /// Features after the width projection and optional type-embedding add,
    /// before the gather through `mapping`.
    pub debug_g1_post_xform: Option<NodeId>,
    /// Last-axis width of `descriptor`: `g1_dim`, plus `tebd_dim` when
    /// `concat_output_tebd` is set.
    pub dim_out: usize,
}
pub fn build_dpa2_descriptor(
g: &mut Graph,
cfg: &DPA2Config,
inputs: DPA2Inputs,
nf: usize,
nloc: usize,
nall: usize,
) -> Result<DPA2Descriptor> {
let activation = ActivationKind::parse(&cfg.activation_function)?;
let repinit_cfg = repinit_to_dpa1(cfg, &cfg.repinit);
let dpa1_inputs = DPA1Inputs {
env_mat_raw: inputs.env_mat_raw_repinit,
atype_loc: inputs.atype_loc,
nei_atype: inputs.nei_atype,
nlist_mask: inputs.nlist_mask_repinit,
sw: inputs.sw_repinit,
type_embedding: inputs.type_embedding,
};
let repinit_out = build_dpa1_descriptor(g, &repinit_cfg, dpa1_inputs, nf, nloc)?;
let mut g1 = repinit_out.descriptor; let dbg_repinit = Some(g1);
if cfg.repinit.use_three_body {
use crate::descriptor_se_t_tebd::{
build_se_t_tebd_descriptor, SeTTebdConfig, SeTTebdInputs,
};
let three_args = cfg.repinit_three_body.as_ref().ok_or_else(|| {
anyhow::anyhow!(
"use_three_body=True requires cfg.repinit_three_body to be set"
)
})?;
let env_three = inputs.env_mat_raw_three_body.ok_or_else(|| {
anyhow::anyhow!("use_three_body=True requires env_mat_raw_three_body")
})?;
let nei_three = inputs.nei_atype_three_body.ok_or_else(|| {
anyhow::anyhow!("use_three_body=True requires nei_atype_three_body")
})?;
let sw_three = inputs.sw_three_body.ok_or_else(|| {
anyhow::anyhow!("use_three_body=True requires sw_three_body")
})?;
let three_cfg = SeTTebdConfig {
rcut: three_args.rcut,
rcut_smth: three_args.rcut_smth,
sel: vec![three_args.nsel],
ntypes: cfg.ntypes,
neuron: three_args.neuron.clone(),
tebd_dim: cfg.tebd_dim,
tebd_input_mode: three_args.tebd_input_mode,
resnet_dt: three_args.resnet_dt,
smooth: true,
activation_function: cfg.activation_function.clone(),
param_prefix: "repinit_three_body".into(),
};
let three_inputs = SeTTebdInputs {
env_mat_raw: env_three,
atype_loc: inputs.atype_loc,
nei_atype: nei_three,
type_embedding: inputs.type_embedding,
sw: Some(sw_three),
exclude_mask: None,
};
let three_out = build_se_t_tebd_descriptor(g, &three_cfg, three_inputs, nf, nloc)?;
let main_dim = match g.shape(g1).dim(2) {
rlx_ir::Dim::Static(n) => n,
_ => return Err(anyhow::anyhow!("dpa2: g1 last dim must be static")),
};
let cat_shape = Shape::new(
&[nf, nloc, main_dim + three_out.dim_out],
DType::F32,
);
g1 = g.concat(vec![g1, three_out.descriptor], 2, cat_shape);
}
let ng1 = cfg.repformers.g1_dim;
let g1_in_dim = match g.shape(g1).dim(2) {
rlx_ir::Dim::Static(n) => n,
_ => return Err(anyhow::anyhow!("dpa2: g1 last dim must be static")),
};
let mut g1 = if g1_in_dim == ng1 {
g1
} else {
let xform_w = g.param(
"g1_shape_tranform.w",
Shape::new(&[g1_in_dim, ng1], DType::F32),
);
g.mm(g1, xform_w)
};
if cfg.add_tebd_to_repinit_out {
let centre = g.gather_(inputs.type_embedding, inputs.atype_loc, 0);
let trans = MlpSpec {
param_prefix: "tebd_transform",
in_dim: cfg.tebd_dim,
neuron: &[ng1],
activation,
resnet_dt: false,
};
let tebd_proj = embedding_mlp(g, &trans, centre);
g1 = g.add(g1, tebd_proj);
}
let dbg_g1_xform = Some(g1);
let g1_ext_padded = if nloc == nall {
g1
} else {
let _ = inputs.g1_ext_tebd;
g.gather_(g1, inputs.mapping, 1)
};
let rf_inputs = RepformersInputs {
g1_ext: g1_ext_padded,
env_mat_raw: inputs.env_mat_raw_repformer,
nlist: inputs.nlist_repformer,
nlist_mask: inputs.nlist_mask_repformer,
sw: inputs.sw_repformer,
atype_loc: inputs.atype_loc,
mapping: if nloc == nall { None } else { Some(inputs.mapping) },
};
let rf_out = build_repformers(g, &cfg.repformers, rf_inputs, nf, nloc, nall)?;
let mut out_g1 = rf_out.g1;
if cfg.concat_output_tebd {
let centre = g.gather_(inputs.type_embedding, inputs.atype_loc, 0); let cat_shape =
Shape::new(&[nf, nloc, ng1 + cfg.tebd_dim], DType::F32);
out_g1 = g.concat(vec![out_g1, centre], 2, cat_shape);
}
let dim_out = if cfg.concat_output_tebd {
ng1 + cfg.tebd_dim
} else {
ng1
};
Ok(DPA2Descriptor {
descriptor: out_g1,
gr: rf_out.rot_mat,
debug_repinit: dbg_repinit,
debug_g1_post_xform: dbg_g1_xform,
dim_out,
})
}
/// Translates the DPA-2 repinit arguments into a `DPA1Config` for
/// `build_dpa1_descriptor`.
///
/// Field sources:
/// - `top` (the top-level `DPA2Config`): type count, type-embedding width,
///   activation.
/// - `ri`: cutoffs, selection size and embedding-net settings.
/// - Fixed values: the attention-related fields, `smooth` (always `true`) and
///   the DPA-1 block's own output type-embedding concat (always disabled).
fn repinit_to_dpa1(top: &DPA2Config, ri: &RepinitArgs) -> DPA1Config {
    // Owned values are built first; every other field is a plain copy.
    let sel = vec![ri.nsel];
    let neuron = ri.neuron.clone();
    let activation_function = top.activation_function.clone();

    DPA1Config {
        // Settings shared with the rest of the DPA-2 descriptor.
        ntypes: top.ntypes,
        tebd_dim: top.tebd_dim,
        activation_function,
        // Repinit-specific geometry and embedding net.
        rcut: ri.rcut,
        rcut_smth: ri.rcut_smth,
        sel,
        neuron,
        axis_neuron: ri.axis_neuron,
        tebd_input_mode: ri.tebd_input_mode,
        resnet_dt: ri.resnet_dt,
        type_one_side: ri.type_one_side,
        // Fixed attention settings. `attn_layer: 0` presumably means no
        // attention layers are built — confirm against descriptor_dpa1.
        attn: 16,
        attn_layer: 0,
        attn_dotr: false,
        num_heads: 1,
        normalize: false,
        scaling_factor: 1.0,
        ln_eps: 1e-5,
        smooth: true,
        concat_output_tebd: false,
    }
}