use anyhow::Result;
use rlx_ir::infer::GraphExt;
use rlx_ir::{DType, Graph, NodeId, Shape};
use serde::{Deserialize, Serialize};
use crate::descriptor_repflows::{
build_repflows, RepflowsConfig, RepflowsInputs,
};
/// Configuration for the DPA3 descriptor built by `build_dpa3_descriptor`.
///
/// Serde defaults: `tebd_dim = 8`, `concat_output_tebd = true`,
/// `activation_function = "silu"`, `use_loc_mapping = false`,
/// `add_chg_spin_ebd = false`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct DPA3Config {
/// Settings forwarded unchanged to `build_repflows`; its `n_dim` is the
/// width of the RepFlows node embedding.
pub repflows: RepflowsConfig,
/// Number of atom types.
/// NOTE(review): not read by `build_dpa3_descriptor`; presumably consumed
/// elsewhere (e.g. when sizing the type-embedding table) — confirm.
pub ntypes: usize,
/// Width of the type embedding. It is also the width of the charge/spin
/// embeddings and of the mixing layer output, and it is the number of
/// columns appended to the descriptor when `concat_output_tebd` is set.
#[serde(default = "default_tebd_dim")]
pub tebd_dim: usize,
/// When true, each local atom's type embedding is concatenated onto the
/// RepFlows node embedding along axis 2.
#[serde(default = "default_true")]
pub concat_output_tebd: bool,
/// NOTE(review): currently accepted but not consumed by
/// `build_dpa3_descriptor` — confirm whether local mapping is needed.
#[serde(default)]
pub use_loc_mapping: bool,
/// When true, a per-frame charge/spin embedding is added to every atom of
/// `node_ebd_ext` before RepFlows. This requires the `charge_spin`,
/// `chg_table` and `spin_table` inputs.
#[serde(default)]
pub add_chg_spin_ebd: bool,
/// Name of the activation applied after the charge/spin mixing layer.
/// It is parsed by `ActivationKind::parse`, and only when
/// `add_chg_spin_ebd` is true.
#[serde(default = "default_chg_spin_activation")]
pub activation_function: String,
}
/// Serde default for `DPA3Config::activation_function`.
fn default_chg_spin_activation() -> String {
    String::from("silu")
}

/// Serde default for `DPA3Config::tebd_dim`.
const fn default_tebd_dim() -> usize {
    8
}

/// Serde default for boolean flags that should be on unless configured off.
const fn default_true() -> bool {
    true
}
/// Graph nodes consumed by `build_dpa3_descriptor`.
pub struct DPA3Inputs {
/// Forwarded to RepFlows.
pub env_mat_raw: NodeId,
/// Forwarded to RepFlows. Also used as the axis-0 gather index into
/// `type_embedding` when `concat_output_tebd` is set.
pub atype_loc: NodeId,
/// Forwarded to RepFlows.
pub nlist: NodeId,
/// Forwarded to RepFlows.
pub nlist_mask: NodeId,
/// Forwarded to RepFlows.
pub sw: NodeId,
/// Forwarded to RepFlows.
pub angle_input: NodeId,
/// Forwarded to RepFlows.
pub angle_mask: NodeId,
/// Table gathered along axis 0 by `atype_loc` to build the appended
/// type-embedding columns.
pub type_embedding: NodeId,
/// Initial node embedding passed to RepFlows. When `add_chg_spin_ebd` is
/// set, its dimension 1 must be static.
pub node_ebd_ext: NodeId,
/// Required when `add_chg_spin_ebd` is set. Column 0 (axis 1) is the charge
/// and column 1 is the spin; presumably shaped `[nf, 2]` — confirm.
pub charge_spin: Option<NodeId>,
/// Charge embedding table, gathered along axis 0 at `charge + 100`.
/// Required when `add_chg_spin_ebd` is set.
pub chg_table: Option<NodeId>,
/// Spin embedding table, gathered along axis 0 at `spin`.
/// Required when `add_chg_spin_ebd` is set.
pub spin_table: Option<NodeId>,
}
/// Output handles of `build_dpa3_descriptor`.
pub struct DPA3Descriptor {
/// The RepFlows node embedding. When `concat_output_tebd` is set, the
/// local-atom type embedding is concatenated onto it along axis 2.
pub descriptor: NodeId,
/// The `rot_mat` output of RepFlows, passed through unchanged.
pub gr: NodeId,
/// Width of the last axis of `descriptor`: `repflows.n_dim`, plus
/// `tebd_dim` when `concat_output_tebd` is set.
pub dim_out: usize,
}
pub fn build_dpa3_descriptor(
g: &mut Graph,
cfg: &DPA3Config,
inputs: DPA3Inputs,
nf: usize,
nloc: usize,
nall: usize,
) -> Result<DPA3Descriptor> {
use crate::nn::{scalar_const, ActivationKind, apply_activation};
let node_ebd_ext_final = if cfg.add_chg_spin_ebd {
let activation = ActivationKind::parse(&cfg.activation_function)?;
let cs = inputs.charge_spin.ok_or_else(|| {
anyhow::anyhow!("add_chg_spin_ebd=True requires charge_spin input")
})?;
let chg_table = inputs.chg_table.ok_or_else(|| {
anyhow::anyhow!("add_chg_spin_ebd=True requires chg_table input")
})?;
let spin_table = inputs.spin_table.ok_or_else(|| {
anyhow::anyhow!("add_chg_spin_ebd=True requires spin_table input")
})?;
let charge_f = g.narrow_(cs, 1, 0, 1); let charge_f_2d = g.reshape(
charge_f, vec![nf as i64, 1],
Shape::new(&[nf, 1], DType::F32),
);
let one_hundred = scalar_const(g, 100.0);
let charge_shifted = g.add(charge_f_2d, one_hundred);
let charge_idx = g.reshape(
charge_shifted, vec![nf as i64],
Shape::new(&[nf], DType::F32),
);
let spin_f = g.narrow_(cs, 1, 1, 1); let spin_idx = g.reshape(
spin_f, vec![nf as i64],
Shape::new(&[nf], DType::F32),
);
let chg_ebd = g.gather_(chg_table, charge_idx, 0);
let spin_ebd = g.gather_(spin_table, spin_idx, 0);
let cs_cat = g.concat(
vec![chg_ebd, spin_ebd], 1,
Shape::new(&[nf, 2 * cfg.tebd_dim], DType::F32),
);
let mix_w = g.param(
"mix_cs_mlp.w",
Shape::new(&[2 * cfg.tebd_dim, cfg.tebd_dim], DType::F32),
);
let mix_b = g.param(
"mix_cs_mlp.b",
Shape::new(&[cfg.tebd_dim], DType::F32),
);
let pre = g.mm(cs_cat, mix_w);
let pre = g.add(pre, mix_b);
let sys_cs = apply_activation(g, activation, pre); let node_shape = g.shape(inputs.node_ebd_ext).clone();
let atom_dim = match node_shape.dim(1) {
rlx_ir::Dim::Static(d) => d,
_ => return Err(anyhow::anyhow!("dpa3: node_ebd_ext atom dim must be static")),
};
let sys_cs_3d = g.reshape(
sys_cs, vec![nf as i64, 1, cfg.tebd_dim as i64],
Shape::new(&[nf, 1, cfg.tebd_dim], DType::F32),
);
let mut tiles = Vec::with_capacity(atom_dim);
for _ in 0..atom_dim {
tiles.push(sys_cs_3d);
}
let sys_cs_broadcast = g.concat(
tiles, 1,
Shape::new(&[nf, atom_dim, cfg.tebd_dim], DType::F32),
);
g.add(inputs.node_ebd_ext, sys_cs_broadcast)
} else {
inputs.node_ebd_ext
};
let rf_inputs = RepflowsInputs {
node_ebd_ext: node_ebd_ext_final,
env_mat_raw: inputs.env_mat_raw,
nlist: inputs.nlist,
nlist_mask: inputs.nlist_mask,
sw: inputs.sw,
angle_input: inputs.angle_input,
angle_mask: inputs.angle_mask,
a_sw: None,
atype_loc: inputs.atype_loc,
};
let rf_out = build_repflows(g, &cfg.repflows, rf_inputs, nf, nloc, nall)?;
let mut node_ebd = rf_out.node_ebd;
let n_dim = cfg.repflows.n_dim;
if cfg.concat_output_tebd {
let centre = g.gather_(inputs.type_embedding, inputs.atype_loc, 0); let cat_shape = Shape::new(&[nf, nloc, n_dim + cfg.tebd_dim], DType::F32);
node_ebd = g.concat(vec![node_ebd, centre], 2, cat_shape);
}
let _ = inputs.charge_spin;
let _ = cfg.add_chg_spin_ebd;
let _ = cfg.use_loc_mapping;
let dim_out = if cfg.concat_output_tebd {
n_dim + cfg.tebd_dim
} else {
n_dim
};
Ok(DPA3Descriptor {
descriptor: node_ebd,
gr: rf_out.rot_mat,
dim_out,
})
}