use anyhow::{bail, Result};
use rlx_ir::infer::GraphExt;
use rlx_ir::op::{Activation, BinaryOp, ReduceOp};
use rlx_ir::{DType, Dim, Graph, NodeId, Shape};
use serde::{Deserialize, Serialize};
use crate::descriptor_se_t_tebd::TebdInputMode;
use crate::nn::{
apply_activation, dense_layer, embedding_mlp, scalar_const, ActivationKind,
DenseLayerSpec, MlpSpec,
};
/// Hyper-parameters of the DPA1 descriptor, (de)serialized with serde.
///
/// Fields annotated with `#[serde(default ...)]` fall back to the listed default when
/// absent from the input.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct DPA1Config {
    /// Cutoff radius. NOTE(review): not read by this module; presumably consumed by the
    /// environment-matrix producer — confirm.
    pub rcut: f64,
    /// Smoothing onset radius. NOTE(review): not read by this module — confirm consumer.
    pub rcut_smth: f64,
    /// Per-type neighbor selection counts; their sum is the padded neighbor count `nnei()`.
    pub sel: Vec<usize>,
    /// Number of atom types; first axis of the `descriptor.davg` / `descriptor.dstd` params.
    pub ntypes: usize,
    /// Layer widths of the embedding MLP; the last entry is the embedding width `ng`.
    /// Default `[25, 50, 100]`.
    #[serde(default = "default_neuron")]
    pub neuron: Vec<usize>,
    /// Number of leading `gr` rows used as the second factor of the descriptor product.
    /// Default 8.
    #[serde(default = "default_axis_neuron")]
    pub axis_neuron: usize,
    /// Width of each type-embedding vector. Default 8.
    #[serde(default = "default_tebd_dim")]
    pub tebd_dim: usize,
    /// How type embeddings enter the pair embedding (`Concat` or `Strip`).
    #[serde(default)]
    pub tebd_input_mode: TebdInputMode,
    /// Forwarded as `resnet_dt` to every embedding MLP. Default `false`.
    #[serde(default)]
    pub resnet_dt: bool,
    /// Attention hidden width: `in_proj` maps `ng -> 3 * attn`, `out_proj` maps `attn -> ng`.
    /// Must be divisible by `num_heads`. Default 128.
    #[serde(default = "default_attn")]
    pub attn: usize,
    /// Number of stacked attention layers. Default 2.
    #[serde(default = "default_attn_layer")]
    pub attn_layer: usize,
    /// Multiply attention weights by neighbor-direction dot products. Default `true`.
    #[serde(default = "default_true")]
    pub attn_dotr: bool,
    /// Number of attention heads. Default 1.
    #[serde(default = "default_num_heads")]
    pub num_heads: usize,
    /// L2-normalize per-head q/k/v before the dot product. Default `true`.
    #[serde(default = "default_true")]
    pub normalize: bool,
    /// Extra factor inside the attention scale `1 / sqrt(head_dim * scaling_factor)`;
    /// the serde default `0.0` is treated as `1.0` by the attention builders.
    #[serde(default)]
    pub scaling_factor: f32,
    /// Epsilon of the post-attention layer norm. Default `1e-5`.
    #[serde(default = "default_ln_eps")]
    pub ln_eps: f32,
    /// Gate attention (and strip-mode type factors) with the switch function `sw`
    /// instead of a hard mask. Default `true`.
    #[serde(default = "default_true")]
    pub smooth: bool,
    /// Use only the neighbor's type embedding (not the centre's) in the pair embedding.
    #[serde(default)]
    pub type_one_side: bool,
    /// Activation name, parsed with `ActivationKind::parse`. Default `"tanh"`.
    #[serde(default = "default_activation")]
    pub activation_function: String,
    /// Append the centre atom's type embedding to the descriptor output. Default `true`.
    #[serde(default = "default_true")]
    pub concat_output_tebd: bool,
}
// Serde fallbacks for `DPA1Config`. Each one is referenced by name from a
// `#[serde(default = "...")]` attribute, so the names must stay as they are.

/// Default embedding-MLP layer widths.
fn default_neuron() -> Vec<usize> {
    [25, 50, 100].to_vec()
}

/// Default number of axis neurons.
fn default_axis_neuron() -> usize {
    8
}

/// Default type-embedding width.
fn default_tebd_dim() -> usize {
    8
}

/// Default attention hidden width.
fn default_attn() -> usize {
    128
}

/// Default number of attention layers.
fn default_attn_layer() -> usize {
    2
}

/// Default number of attention heads.
fn default_num_heads() -> usize {
    1
}

/// Default layer-norm epsilon.
fn default_ln_eps() -> f32 {
    1.0e-5
}

/// Shared `true` default for boolean switches.
fn default_true() -> bool {
    !false
}

/// Default activation name.
fn default_activation() -> String {
    String::from("tanh")
}
impl DPA1Config {
    /// Padded neighbor count: the sum of the per-type `sel` entries.
    pub fn nnei(&self) -> usize {
        self.sel.iter().copied().sum()
    }

    /// Embedding width `ng`, i.e. the last entry of `neuron`.
    ///
    /// # Panics
    /// Panics if `neuron` is empty.
    pub fn ng(&self) -> usize {
        self.neuron
            .last()
            .copied()
            .expect("dpa1: empty neuron list")
    }

    /// Descriptor width: `ng * axis_neuron`, plus `tebd_dim` when the centre
    /// type embedding is concatenated to the output.
    pub fn dim_out(&self) -> usize {
        let tebd_extra = if self.concat_output_tebd { self.tebd_dim } else { 0 };
        self.ng() * self.axis_neuron + tebd_extra
    }

    /// Per-head width `attn / num_heads`.
    ///
    /// # Panics
    /// Panics if `num_heads` is zero.
    pub fn head_dim(&self) -> usize {
        let heads = self.num_heads;
        self.attn / heads
    }
}
/// Graph nodes consumed by `build_dpa1_descriptor`.
///
/// NOTE(review): the shapes below are inferred from how the builders reshape and gather
/// these nodes; confirm them against the producer.
pub struct DPA1Inputs {
    /// Raw environment matrix before per-type normalization; must be rank 4
    /// (presumably `[nf, nloc, nnei, 4]`).
    pub env_mat_raw: NodeId,
    /// Centre-atom types, used as gather indices into the type embedding and the
    /// `davg`/`dstd` statistics (presumably `[nf, nloc]`, I32).
    pub atype_loc: NodeId,
    /// Neighbor atom types, used as gather indices; flattened to `nf * nloc * nnei`
    /// I32 entries in strip mode.
    pub nei_atype: NodeId,
    /// Neighbor validity mask (F32), reshaped to `[nf, nloc, nnei, 1]` to zero padded
    /// neighbors.
    pub nlist_mask: NodeId,
    /// Per-neighbor switch-function values, reshaped to `[nf, nloc, nnei, 1]` and
    /// `[nf * nloc, nnei]` by the builders.
    pub sw: NodeId,
    /// Type-embedding table; rows are gathered along axis 0 by atom type.
    pub type_embedding: NodeId,
}
/// Output nodes produced by `build_dpa1_descriptor`.
pub struct DPA1Descriptor {
    /// Descriptor, reshaped to `[nf, nloc, ng * axis_neuron]` and, when
    /// `concat_output_tebd` is set, concatenated with the centre type embedding.
    pub descriptor: NodeId,
    /// Channels 1..4 of `gr = gg^T @ rr / nnei` (presumably `[nf, nloc, ng, 3]`).
    pub gr: NodeId,
    /// Neighbor embedding after the attention stack, `[nf, nloc, nnei, ng]`.
    pub gg: NodeId,
    /// Input of the last attention layer (`None` when `attn_layer == 0`).
    pub debug_attn_query: Option<NodeId>,
    /// Last layer's attention weights after softmax and post-softmax gating.
    pub debug_attn_aw_post: Option<NodeId>,
    /// Last layer's merged head output before `out_proj`, `[nf * nloc, nnei, attn]`.
    pub debug_attn_o_pre_proj: Option<NodeId>,
    /// Unit neighbor direction vectors used by the `attn_dotr` term.
    pub debug_input_r: Option<NodeId>,
    /// Angular (direction dot-product) debug tap; the builder in this file sets it to `None`.
    pub debug_angular_4d: Option<NodeId>,
    /// Descriptor width; equals `DPA1Config::dim_out()`.
    pub dim_out: usize,
}
/// Builds the DPA1 descriptor subgraph and returns its output nodes.
///
/// Steps:
/// 1. Normalize the raw environment matrix as `(env - davg[type]) / dstd[type]` with the
///    `descriptor.davg` / `descriptor.dstd` params (`[ntypes, nnei, 4]`), then zero padded
///    neighbors with `nlist_mask`.
/// 2. Build the per-neighbor embedding `gg` (`build_pair_embedding`) and refine it with
///    the gated-attention stack.
/// 3. Form `gr = gg^T @ rr / nnei` and contract it with its first `axis_neuron` rows:
///    `d = gr @ gr[.., :axis_neuron, :]^T`, flattened to `[nf, nloc, ng * axis_neuron]`,
///    optionally followed by the centre atom's type embedding.
///
/// # Errors
/// Returns an error for an inconsistent configuration (`num_heads == 0`, `attn` not
/// divisible by `num_heads`, empty `neuron`, empty `sel`, `axis_neuron > ng`), an unknown
/// activation name, a non-rank-4 environment matrix, or a failure in a sub-builder.
pub fn build_dpa1_descriptor(
    g: &mut Graph,
    cfg: &DPA1Config,
    inputs: DPA1Inputs,
    nf: usize,
    nloc: usize,
) -> Result<DPA1Descriptor> {
    // Validate up front. Without these checks, `%`/`/` by a zero head count and `ng()` on an
    // empty neuron list would panic, and a zero neighbor count would scale `gr` by 1/0.
    if cfg.num_heads == 0 {
        bail!("dpa1: num_heads must be at least 1");
    }
    if cfg.attn % cfg.num_heads != 0 {
        bail!(
            "dpa1: attn ({}) must be divisible by num_heads ({})",
            cfg.attn,
            cfg.num_heads
        );
    }
    if cfg.neuron.is_empty() {
        bail!("dpa1: empty neuron list");
    }
    let activation = ActivationKind::parse(&cfg.activation_function)?;
    let ng = cfg.ng();
    let nnei = cfg.nnei();
    if nnei == 0 {
        bail!("dpa1: sel must select at least one neighbor");
    }
    let m2 = cfg.axis_neuron;
    if m2 > ng {
        bail!(
            "dpa1: axis_neuron ({}) must not exceed the embedding width ({})",
            m2,
            ng
        );
    }
    let tebd = cfg.tebd_dim;
    let ntypes = cfg.ntypes;

    if g.shape(inputs.env_mat_raw).rank() != 4 {
        bail!("dpa1: env-matrix input must have rank 4");
    }

    // Per-type statistics, gathered by the centre atom type.
    let stat_shape = Shape::new(&[ntypes, nnei, 4], DType::F32);
    let davg = g.param("descriptor.davg", stat_shape.clone());
    let dstd = g.param("descriptor.dstd", stat_shape);
    let davg_g = g.gather_(davg, inputs.atype_loc, 0);
    let dstd_g = g.gather_(dstd, inputs.atype_loc, 0);
    let centred = g.sub(inputs.env_mat_raw, davg_g);
    let scaled = g.div(centred, dstd_g);

    // Zero padded neighbors.
    let mask_4d = g.reshape(
        inputs.nlist_mask,
        vec![nf as i64, nloc as i64, nnei as i64, 1],
        Shape::new(&[nf, nloc, nnei, 1], DType::F32),
    );
    let rr = g.mul(scaled, mask_4d);

    let gg_initial = build_pair_embedding(g, cfg, &inputs, rr, activation, nf, nloc, nnei, ng)?;
    let (gg, dbg_query, dbg_aw_post, dbg_o, dbg_input_r) =
        build_neighbor_gated_attention_with_debug(
            g,
            cfg,
            gg_initial,
            inputs.nlist_mask,
            rr,
            inputs.sw,
            nf,
            nloc,
            nnei,
            ng,
        )?;

    // gr = gg^T @ rr / nnei
    let gg_t = g.transpose_(gg, vec![0, 1, 3, 2]);
    let gr_sum = g.mm(gg_t, rr);
    let inv_nnei = scalar_const(g, 1.0 / nnei as f32);
    let gr = g.mul(gr_sum, inv_nnei);

    // d = gr @ gr[.., :m2, :]^T, flattened per atom.
    let gr_lt = g.narrow_(gr, 2, 0, m2);
    let gr_lt_t = g.transpose_(gr_lt, vec![0, 1, 3, 2]);
    let d = g.mm(gr, gr_lt_t);
    let core_dim = ng * m2;
    let mut descriptor = g.reshape(
        d,
        vec![nf as i64, nloc as i64, core_dim as i64],
        Shape::new(&[nf, nloc, core_dim], DType::F32),
    );
    if cfg.concat_output_tebd {
        let centre_tebd = g.gather_(inputs.type_embedding, inputs.atype_loc, 0);
        descriptor = g.concat(
            vec![descriptor, centre_tebd],
            2,
            Shape::new(&[nf, nloc, core_dim + tebd], DType::F32),
        );
    }

    // Directional part of gr (channels 1..4).
    let gr_vec = g.narrow_(gr, 3, 1, 3);
    Ok(DPA1Descriptor {
        descriptor,
        gr: gr_vec,
        gg,
        debug_attn_query: dbg_query,
        debug_attn_aw_post: dbg_aw_post,
        debug_attn_o_pre_proj: dbg_o,
        debug_input_r: dbg_input_r,
        debug_angular_4d: None,
        dim_out: cfg.dim_out(),
    })
}
fn build_pair_embedding(
g: &mut Graph,
cfg: &DPA1Config,
inputs: &DPA1Inputs,
rr: NodeId,
activation: ActivationKind,
nf: usize,
nloc: usize,
nnei: usize,
ng: usize,
) -> Result<NodeId> {
let ss = g.narrow_(rr, 3, 0, 1); let tebd = cfg.tebd_dim;
let centre_emb = g.gather_(inputs.type_embedding, inputs.atype_loc, 0); let centre_4d_shape = Shape::new(&[nf, nloc, 1, tebd], DType::F32);
let centre_4d = g.reshape(
centre_emb,
vec![nf as i64, nloc as i64, 1, tebd as i64],
centre_4d_shape,
);
let zero_centre_full = zero_f32(g, &[nf, nloc, nnei, tebd]);
let centre_b = g.add(centre_4d, zero_centre_full);
let nei_emb = g.gather_(inputs.type_embedding, inputs.nei_atype, 0);
match cfg.tebd_input_mode {
TebdInputMode::Concat => {
let cat_dim = 1
+ tebd
+ if !cfg.type_one_side { tebd } else { 0 };
let ss_concat_shape = Shape::new(&[nf, nloc, nnei, cat_dim], DType::F32);
let ss_concat = if cfg.type_one_side {
g.concat(vec![ss, nei_emb], 3, ss_concat_shape)
} else {
g.concat(vec![ss, nei_emb, centre_b], 3, ss_concat_shape)
};
let prefix = "descriptor.embedding";
let mlp = MlpSpec {
param_prefix: prefix,
in_dim: cat_dim,
neuron: &cfg.neuron,
activation,
resnet_dt: cfg.resnet_dt,
};
Ok(embedding_mlp(g, &mlp, ss_concat)) }
TebdInputMode::Strip => {
let mlp_main = MlpSpec {
param_prefix: "descriptor.embedding",
in_dim: 1,
neuron: &cfg.neuron,
activation,
resnet_dt: cfg.resnet_dt,
};
let gg_s = embedding_mlp(g, &mlp_main, ss); let rows = match g.shape(inputs.type_embedding).dim(0) {
Dim::Static(n) => n,
_ => bail!("dpa1: type embedding rows must be static"),
};
let (tt_full, in_dim_for_strip, lookup_idx) = if cfg.type_one_side {
let in_dim = tebd;
let mlp_strip = MlpSpec {
param_prefix: "descriptor.embedding_strip",
in_dim,
neuron: &cfg.neuron,
activation,
resnet_dt: cfg.resnet_dt,
};
let tt_full = embedding_mlp(g, &mlp_strip, inputs.type_embedding);
(tt_full, in_dim, inputs.nei_atype)
} else {
let pair = build_pair_table(g, inputs.type_embedding, rows, tebd)?;
let in_dim = 2 * tebd;
let mlp_strip = MlpSpec {
param_prefix: "descriptor.embedding_strip",
in_dim,
neuron: &cfg.neuron,
activation,
resnet_dt: cfg.resnet_dt,
};
let tt_full = embedding_mlp(g, &mlp_strip, pair); let rows_c = scalar_i32(g, rows as i32);
let center_4d_shape = Shape::new(&[nf, nloc, 1], DType::I32);
let center_3d = g.reshape(
inputs.atype_loc,
vec![nf as i64, nloc as i64, 1],
center_4d_shape,
);
let zero_n = i32_zero_tensor(g, &[nf, nloc, nnei]);
let center_b = g.binary(
BinaryOp::Add,
center_3d,
zero_n,
Shape::new(&[nf, nloc, nnei], DType::I32),
);
let center_scaled = g.binary(
BinaryOp::Mul,
center_b,
rows_c,
g.shape(center_b).clone(),
);
let pair_idx = g.binary(
BinaryOp::Add,
center_scaled,
inputs.nei_atype,
g.shape(center_scaled).clone(),
);
(tt_full, in_dim, pair_idx)
};
let _ = in_dim_for_strip;
let flat_total = nf * nloc * nnei;
let lookup_flat = g.reshape(
lookup_idx,
vec![flat_total as i64],
Shape::new(&[flat_total], DType::I32),
);
let gg_t_flat = g.gather_(tt_full, lookup_flat, 0);
let gg_t_shape = Shape::new(&[nf, nloc, nnei, ng], DType::F32);
let mut gg_t = g.reshape(
gg_t_flat,
vec![nf as i64, nloc as i64, nnei as i64, ng as i64],
gg_t_shape,
);
if cfg.smooth {
let sw_4d_shape = Shape::new(&[nf, nloc, nnei, 1], DType::F32);
let sw_4d = g.reshape(
inputs.sw,
vec![nf as i64, nloc as i64, nnei as i64, 1],
sw_4d_shape,
);
gg_t = g.mul(gg_t, sw_4d);
}
let prod = g.mul(gg_s, gg_t);
Ok(g.add(prod, gg_s))
}
}
}
/// Runs the `attn_layer`-deep gated-attention stack over `gg_in`.
///
/// Returns `(output, last-layer input, last-layer post-gating attention weights,
/// last-layer pre-projection output, unit neighbor directions)`. The three per-layer
/// taps are `None` when `attn_layer == 0`.
///
/// The unit directions are channels 1..4 of `rr` divided by their L2 norm. The squared
/// norm is clamped at `1e-24` so padded (zero) neighbors do not divide by zero.
fn build_neighbor_gated_attention_with_debug(
    g: &mut Graph,
    cfg: &DPA1Config,
    gg_in: NodeId,
    nlist_mask: NodeId,
    rr: NodeId,
    sw: NodeId,
    nf: usize,
    nloc: usize,
    nnei: usize,
    ng: usize,
) -> Result<(NodeId, Option<NodeId>, Option<NodeId>, Option<NodeId>, Option<NodeId>)> {
    let input_r = {
        let xyz = g.narrow_(rr, 3, 1, 3);
        let xyz_sq = g.mul(xyz, xyz);
        let norm_sq = g.reduce(
            xyz_sq,
            ReduceOp::Sum,
            vec![3],
            true,
            Shape::new(&[nf, nloc, nnei, 1], DType::F32),
        );
        let floor = scalar_const(g, 1e-24);
        let norm_sq_shape = g.shape(norm_sq).clone();
        let clamped = g.binary(BinaryOp::Max, norm_sq, floor, norm_sq_shape);
        let clamped_shape = g.shape(clamped).clone();
        let norm = g.activation(Activation::Sqrt, clamped, clamped_shape);
        g.div(xyz, norm)
    };

    let mut current = gg_in;
    let mut last_taps: Option<(NodeId, NodeId, NodeId)> = None;
    for layer in 0..cfg.attn_layer {
        let prefix = format!("descriptor.attention.layer{layer}");
        let (next, q_in, aw_post, o_pre) = build_attention_layer_with_debug(
            g, cfg, &prefix, current, nlist_mask, sw, input_r, nf, nloc, nnei, ng,
        )?;
        current = next;
        last_taps = Some((q_in, aw_post, o_pre));
    }

    let (dbg_q, dbg_aw, dbg_o) = match last_taps {
        Some((q_in, aw_post, o_pre)) => (Some(q_in), Some(aw_post), Some(o_pre)),
        None => (None, None, None),
    };
    Ok((current, dbg_q, dbg_aw, dbg_o, Some(input_r)))
}
/// One post-norm attention layer, `ln(x + attn(x))`, using the
/// `{prefix}.attn_layer_norm.{gamma,beta}` params of width `ng`.
///
/// Returns `(output, layer input x, post-gating attention weights, pre-projection output)`.
fn build_attention_layer_with_debug(
    g: &mut Graph,
    cfg: &DPA1Config,
    prefix: &str,
    x: NodeId,
    nlist_mask: NodeId,
    sw: NodeId,
    input_r: NodeId,
    nf: usize,
    nloc: usize,
    nnei: usize,
    ng: usize,
) -> Result<(NodeId, NodeId, NodeId, NodeId)> {
    let (attn_out, aw_post, o_pre) = build_gated_attention_with_debug(
        g, cfg, prefix, x, nlist_mask, sw, input_r, nf, nloc, nnei, ng,
    )?;
    let skip = g.add(x, attn_out);
    let gamma = g.param(
        format!("{prefix}.attn_layer_norm.gamma"),
        Shape::new(&[ng], DType::F32),
    );
    let beta = g.param(
        format!("{prefix}.attn_layer_norm.beta"),
        Shape::new(&[ng], DType::F32),
    );
    let normed = g.ln(skip, gamma, beta, cfg.ln_eps);
    // The layer's own input doubles as the "query" debug tap.
    Ok((normed, x, aw_post, o_pre))
}
// Debug-free variant of `build_neighbor_gated_attention_with_debug`. Nothing in this
// file calls it, hence the lint allowance.
#[allow(dead_code)]
/// Runs the `attn_layer`-deep gated-attention stack over `gg_in` and returns only the
/// final neighbor embedding.
///
/// Unit neighbor directions (channels 1..4 of `rr` over their L2 norm, with the squared
/// norm clamped at `1e-24`) feed the `attn_dotr` term of every layer.
fn build_neighbor_gated_attention(
    g: &mut Graph,
    cfg: &DPA1Config,
    gg_in: NodeId,
    nlist_mask: NodeId,
    rr: NodeId,
    sw: NodeId,
    nf: usize,
    nloc: usize,
    nnei: usize,
    ng: usize,
) -> Result<NodeId> {
    let input_r = {
        let xyz = g.narrow_(rr, 3, 1, 3);
        let xyz_sq = g.mul(xyz, xyz);
        let norm_sq = g.reduce(
            xyz_sq,
            ReduceOp::Sum,
            vec![3],
            true,
            Shape::new(&[nf, nloc, nnei, 1], DType::F32),
        );
        let floor = scalar_const(g, 1e-24);
        let norm_sq_shape = g.shape(norm_sq).clone();
        let clamped = g.binary(BinaryOp::Max, norm_sq, floor, norm_sq_shape);
        let clamped_shape = g.shape(clamped).clone();
        let norm = g.activation(Activation::Sqrt, clamped, clamped_shape);
        g.div(xyz, norm)
    };

    (0..cfg.attn_layer).try_fold(gg_in, |current, layer| {
        let prefix = format!("descriptor.attention.layer{layer}");
        build_attention_layer(
            g, cfg, &prefix, current, nlist_mask, sw, input_r, nf, nloc, nnei, ng,
        )
    })
}
/// One post-norm attention layer, `ln(x + attn(x))`, with the
/// `{prefix}.attn_layer_norm.{gamma,beta}` params of width `ng`.
fn build_attention_layer(
    g: &mut Graph,
    cfg: &DPA1Config,
    prefix: &str,
    x: NodeId,
    nlist_mask: NodeId,
    sw: NodeId,
    input_r: NodeId,
    nf: usize,
    nloc: usize,
    nnei: usize,
    ng: usize,
) -> Result<NodeId> {
    let attended = build_gated_attention(
        g, cfg, prefix, x, nlist_mask, sw, input_r, nf, nloc, nnei, ng,
    )?;
    let skip = g.add(x, attended);
    let gamma = g.param(
        format!("{prefix}.attn_layer_norm.gamma"),
        Shape::new(&[ng], DType::F32),
    );
    let beta = g.param(
        format!("{prefix}.attn_layer_norm.beta"),
        Shape::new(&[ng], DType::F32),
    );
    let normed = g.ln(skip, gamma, beta, cfg.ln_eps);
    Ok(normed)
}
fn build_gated_attention_with_debug(
g: &mut Graph,
cfg: &DPA1Config,
prefix: &str,
x: NodeId,
nlist_mask: NodeId,
sw: NodeId,
input_r: NodeId,
nf: usize,
nloc: usize,
nnei: usize,
ng: usize,
) -> Result<(NodeId, NodeId, NodeId)> {
let head_dim = cfg.head_dim();
let h = cfg.num_heads;
let hidden = cfg.attn;
let in_proj_w = g.param(
format!("{prefix}.in_proj.w"),
Shape::new(&[ng, 3 * hidden], DType::F32),
);
let in_proj_b = g.param(
format!("{prefix}.in_proj.b"),
Shape::new(&[3 * hidden], DType::F32),
);
let projected_mm = g.mm(x, in_proj_w);
let projected = g.add(projected_mm, in_proj_b);
let debug_in_proj = projected;
let _ = debug_in_proj;
let q = g.narrow_(projected, 3, 0, head_dim);
let k = g.narrow_(projected, 3, head_dim, head_dim);
let v = g.narrow_(projected, 3, 2 * head_dim, head_dim);
let inner_dim = head_dim / h;
let reshape_to_heads = |g: &mut Graph, t: NodeId| -> NodeId {
let r = g.reshape(
t,
vec![(nf * nloc) as i64, nnei as i64, h as i64, inner_dim as i64],
Shape::new(&[nf * nloc, nnei, h, inner_dim], DType::F32),
);
g.transpose_(r, vec![0, 2, 1, 3])
};
let mut q = reshape_to_heads(g, q);
let mut k = reshape_to_heads(g, k);
let mut v = reshape_to_heads(g, v);
if cfg.normalize {
q = l2_normalize(g, q);
k = l2_normalize(g, k);
v = l2_normalize(g, v);
}
let sf = if cfg.scaling_factor == 0.0 { 1.0 } else { cfg.scaling_factor };
let scaling = 1.0 / (head_dim as f32 * sf).sqrt();
let s = scalar_const(g, scaling);
q = g.mul(q, s);
let k_t = g.transpose_(k, vec![0, 1, 3, 2]);
let mut aw = g.mm(q, k_t);
if cfg.smooth {
let sw_flat_shape = Shape::new(&[nf * nloc, nnei], DType::F32);
let sw_flat = g.reshape(sw, vec![(nf * nloc) as i64, nnei as i64], sw_flat_shape);
let sw_b = g.reshape(
sw_flat,
vec![(nf * nloc) as i64, 1, 1, nnei as i64],
Shape::new(&[nf * nloc, 1, 1, nnei], DType::F32),
);
let sw_a = g.reshape(
sw_flat,
vec![(nf * nloc) as i64, 1, nnei as i64, 1],
Shape::new(&[nf * nloc, 1, nnei, 1], DType::F32),
);
let sw_prod = g.mul(sw_a, sw_b);
let aw_scaled = g.mul(aw, sw_prod);
let one = scalar_const(g, 1.0);
let sw_prod_shape = g.shape(sw_prod).clone();
let sw_minus_one = g.binary(
rlx_ir::op::BinaryOp::Sub,
sw_prod,
one,
sw_prod_shape,
);
let twenty = scalar_const(g, 20.0);
let bias_shape = g.shape(sw_minus_one).clone();
let bias = g.binary(
rlx_ir::op::BinaryOp::Mul,
twenty,
sw_minus_one,
bias_shape,
);
aw = g.add(aw_scaled, bias);
} else {
let mask_flat = g.reshape(
nlist_mask,
vec![(nf * nloc) as i64, 1, 1, nnei as i64],
Shape::new(&[nf * nloc, 1, 1, nnei], DType::F32),
);
let one = scalar_const(g, 1.0);
let inv = g.sub(one, mask_flat);
let neg_huge = scalar_const(g, -1e30);
let add = g.mul(inv, neg_huge);
aw = g.add(aw, add);
}
let aw_shape = g.shape(aw).clone();
aw = g.softmax(aw, -1, aw_shape);
let mask_q_shape = Shape::new(&[nf * nloc, 1, nnei, 1], DType::F32);
let mut combined = g.reshape(
nlist_mask,
vec![(nf * nloc) as i64, 1, nnei as i64, 1],
mask_q_shape,
);
if cfg.smooth {
let sw_flat = g.reshape(
sw,
vec![(nf * nloc) as i64, nnei as i64],
Shape::new(&[nf * nloc, nnei], DType::F32),
);
let sw_b = g.reshape(
sw_flat,
vec![(nf * nloc) as i64, 1, 1, nnei as i64],
Shape::new(&[nf * nloc, 1, 1, nnei], DType::F32),
);
let sw_a = g.reshape(
sw_flat,
vec![(nf * nloc) as i64, 1, nnei as i64, 1],
Shape::new(&[nf * nloc, 1, nnei, 1], DType::F32),
);
let m_sw_a_shape = Shape::new(&[nf * nloc, 1, nnei, 1], DType::F32);
let m_sw_a = g.binary(BinaryOp::Mul, combined, sw_a, m_sw_a_shape);
let m_full_shape = Shape::new(&[nf * nloc, 1, nnei, nnei], DType::F32);
combined = g.binary(BinaryOp::Mul, m_sw_a, sw_b, m_full_shape);
}
let mut dbg_angular: Option<NodeId> = None;
if cfg.attn_dotr {
let r_flat_shape = Shape::new(&[nf * nloc, nnei, 3], DType::F32);
let r_flat = g.reshape(
input_r,
vec![(nf * nloc) as i64, nnei as i64, 3],
r_flat_shape,
);
let r_flat_t = g.transpose_(r_flat, vec![0, 2, 1]);
let angular = g.mm(r_flat, r_flat_t);
let angular_4d = g.reshape(
angular,
vec![(nf * nloc) as i64, 1, nnei as i64, nnei as i64],
Shape::new(&[nf * nloc, 1, nnei, nnei], DType::F32),
);
dbg_angular = Some(angular_4d);
let comb_shape = Shape::new(&[nf * nloc, 1, nnei, nnei], DType::F32);
combined = g.binary(BinaryOp::Mul, combined, angular_4d, comb_shape);
}
let aw_shape = g.shape(aw).clone();
aw = g.binary(BinaryOp::Mul, aw, combined, aw_shape);
let aw_post = aw;
let o = g.mm(aw, v);
let o_t = g.transpose_(o, vec![0, 2, 1, 3]);
let o_reshape_shape = Shape::new(&[nf * nloc, nnei, hidden], DType::F32);
let o_flat = g.reshape(
o_t,
vec![(nf * nloc) as i64, nnei as i64, hidden as i64],
o_reshape_shape,
);
let op_w = g.param(
format!("{prefix}.out_proj.w"),
Shape::new(&[hidden, ng], DType::F32),
);
let op_b = g.param(format!("{prefix}.out_proj.b"), Shape::new(&[ng], DType::F32));
let mm = g.mm(o_flat, op_w);
let out_flat = g.add(mm, op_b);
let out_shape = Shape::new(&[nf, nloc, nnei, ng], DType::F32);
let final_out = g.reshape(
out_flat,
vec![nf as i64, nloc as i64, nnei as i64, ng as i64],
out_shape,
);
let _ = dbg_angular;
Ok((final_out, aw_post, o_flat))
}
// Only reachable from the unused `build_neighbor_gated_attention` path.
#[allow(dead_code)]
/// Debug-free twin of `build_gated_attention_with_debug`: one multi-head gated
/// self-attention over each atom's neighbors, returning the `[nf, nloc, nnei, ng]` output.
///
/// q/k/v are `attn` channels each, viewed as `num_heads` heads of
/// `head_dim = attn / num_heads`. Weights are `softmax(q·kᵀ / sqrt(head_dim * scaling_factor))`
/// with smooth switch gating or a hard neighbor mask. They are then multiplied by the
/// query-side mask, by `sw_i * sw_j` (`smooth`), and by the direction dot products
/// (`attn_dotr`).
fn build_gated_attention(
    g: &mut Graph,
    cfg: &DPA1Config,
    prefix: &str,
    x: NodeId,
    nlist_mask: NodeId,
    sw: NodeId,
    input_r: NodeId,
    nf: usize,
    nloc: usize,
    nnei: usize,
    ng: usize,
) -> Result<NodeId> {
    let head_dim = cfg.head_dim();
    let h = cfg.num_heads;
    let hidden = cfg.attn;
    let nfnl = nf * nloc;

    let in_proj_w = g.param(
        format!("{prefix}.in_proj.w"),
        Shape::new(&[ng, 3 * hidden], DType::F32),
    );
    let in_proj_b = g.param(
        format!("{prefix}.in_proj.b"),
        Shape::new(&[3 * hidden], DType::F32),
    );
    let projected_mm = g.mm(x, in_proj_w);
    let projected = g.add(projected_mm, in_proj_b);

    // q / k / v each span `hidden = num_heads * head_dim` channels of the projection.
    let q = g.narrow_(projected, 3, 0, hidden);
    let k = g.narrow_(projected, 3, hidden, hidden);
    let v = g.narrow_(projected, 3, 2 * hidden, hidden);

    // [nf, nloc, nnei, hidden] -> [nf*nloc, num_heads, nnei, head_dim]
    let reshape_to_heads = |g: &mut Graph, t: NodeId| -> NodeId {
        let r = g.reshape(
            t,
            vec![nfnl as i64, nnei as i64, h as i64, head_dim as i64],
            Shape::new(&[nfnl, nnei, h, head_dim], DType::F32),
        );
        g.transpose_(r, vec![0, 2, 1, 3])
    };
    let mut q = reshape_to_heads(g, q);
    let mut k = reshape_to_heads(g, k);
    let mut v = reshape_to_heads(g, v);
    if cfg.normalize {
        q = l2_normalize(g, q);
        k = l2_normalize(g, k);
        v = l2_normalize(g, v);
    }
    // A zero scaling factor (the serde default) means "no extra scaling".
    let sf = if cfg.scaling_factor == 0.0 { 1.0 } else { cfg.scaling_factor };
    let scaling = 1.0 / (head_dim as f32 * sf).sqrt();
    let s = scalar_const(g, scaling);
    q = g.mul(q, s);
    let k_t = g.transpose_(k, vec![0, 1, 3, 2]);
    let mut aw = g.mm(q, k_t);

    // Switch values as a query-side column and a key-side row, built once.
    let sw_col_row = if cfg.smooth {
        let sw_flat = g.reshape(
            sw,
            vec![nfnl as i64, nnei as i64],
            Shape::new(&[nfnl, nnei], DType::F32),
        );
        let sw_row = g.reshape(
            sw_flat,
            vec![nfnl as i64, 1, 1, nnei as i64],
            Shape::new(&[nfnl, 1, 1, nnei], DType::F32),
        );
        let sw_col = g.reshape(
            sw_flat,
            vec![nfnl as i64, 1, nnei as i64, 1],
            Shape::new(&[nfnl, 1, nnei, 1], DType::F32),
        );
        Some((sw_col, sw_row))
    } else {
        None
    };

    if let Some((sw_col, sw_row)) = sw_col_row {
        // (aw + 20) * sw_i * sw_j - 20, written as aw * sw_i * sw_j + 20 * (sw_i * sw_j - 1).
        let sw_prod = g.mul(sw_col, sw_row);
        let aw_scaled = g.mul(aw, sw_prod);
        let one = scalar_const(g, 1.0);
        let sw_prod_shape = g.shape(sw_prod).clone();
        let sw_minus_one = g.binary(BinaryOp::Sub, sw_prod, one, sw_prod_shape);
        let twenty = scalar_const(g, 20.0);
        let bias_shape = g.shape(sw_minus_one).clone();
        let bias = g.binary(BinaryOp::Mul, twenty, sw_minus_one, bias_shape);
        aw = g.add(aw_scaled, bias);
    } else {
        // Hard mask: push padded keys to a huge negative logit before the softmax.
        let mask_flat = g.reshape(
            nlist_mask,
            vec![nfnl as i64, 1, 1, nnei as i64],
            Shape::new(&[nfnl, 1, 1, nnei], DType::F32),
        );
        let one = scalar_const(g, 1.0);
        let inv = g.sub(one, mask_flat);
        let neg_huge = scalar_const(g, -1e30);
        let penalty = g.mul(inv, neg_huge);
        aw = g.add(aw, penalty);
    }
    let aw_shape = g.shape(aw).clone();
    aw = g.softmax(aw, -1, aw_shape);

    // Zero rows belonging to padded query neighbors.
    let mask_q = g.reshape(
        nlist_mask,
        vec![nfnl as i64, 1, nnei as i64, 1],
        Shape::new(&[nfnl, 1, nnei, 1], DType::F32),
    );
    aw = g.mul(aw, mask_q);
    if let Some((sw_col, sw_row)) = sw_col_row {
        aw = g.mul(aw, sw_col);
        aw = g.mul(aw, sw_row);
    }
    if cfg.attn_dotr {
        let r_flat = g.reshape(
            input_r,
            vec![nfnl as i64, nnei as i64, 3],
            Shape::new(&[nfnl, nnei, 3], DType::F32),
        );
        let r_flat_t = g.transpose_(r_flat, vec![0, 2, 1]);
        let angular = g.mm(r_flat, r_flat_t);
        let angular_4d = g.reshape(
            angular,
            vec![nfnl as i64, 1, nnei as i64, nnei as i64],
            Shape::new(&[nfnl, 1, nnei, nnei], DType::F32),
        );
        aw = g.mul(aw, angular_4d);
    }

    // Merge heads: [nfnl, h, nnei, head_dim] -> [nfnl, nnei, h * head_dim = hidden].
    let o = g.mm(aw, v);
    let o_t = g.transpose_(o, vec![0, 2, 1, 3]);
    let o_flat = g.reshape(
        o_t,
        vec![nfnl as i64, nnei as i64, hidden as i64],
        Shape::new(&[nfnl, nnei, hidden], DType::F32),
    );
    let op_w = g.param(
        format!("{prefix}.out_proj.w"),
        Shape::new(&[hidden, ng], DType::F32),
    );
    let op_b = g.param(format!("{prefix}.out_proj.b"), Shape::new(&[ng], DType::F32));
    let mm = g.mm(o_flat, op_w);
    let out_flat = g.add(mm, op_b);
    Ok(g.reshape(
        out_flat,
        vec![nf as i64, nloc as i64, nnei as i64, ng as i64],
        Shape::new(&[nf, nloc, nnei, ng], DType::F32),
    ))
}
/// L2-normalizes `x` along its last axis: `x / max(||x||, 1e-12)`.
///
/// This mirrors `torch.nn.functional.normalize(x, dim=-1)`, whose default `eps = 1e-12`
/// bounds the norm. The clamp here is applied to the squared norm, so its floor is
/// `1e-24`. That matches the neighbor-direction normalization in the attention stack; the
/// previous `1e-12` floor on the squared norm bounded the norm at `1e-6` instead.
///
/// # Panics
/// Panics if `x` is rank 0.
fn l2_normalize(g: &mut Graph, x: NodeId) -> NodeId {
    let x_sq = g.mul(x, x);
    let in_shape = g.shape(x).clone();
    let mut dims: Vec<Dim> = in_shape.dims().to_vec();
    let last = dims
        .len()
        .checked_sub(1)
        .expect("l2_normalize: input must have rank >= 1");
    dims[last] = Dim::Static(1);
    let red_shape = Shape::from_dims(&dims, in_shape.dtype());
    let sum_sq = g.reduce(x_sq, ReduceOp::Sum, vec![last], true, red_shape);
    let eps = scalar_const(g, 1e-24);
    let sum_sq_shape = g.shape(sum_sq).clone();
    let clamped = g.binary(BinaryOp::Max, sum_sq, eps, sum_sq_shape);
    let clamped_shape = g.shape(clamped).clone();
    let norm = g.activation(Activation::Sqrt, clamped, clamped_shape);
    g.div(x, norm)
}
/// Adds an all-zero F32 constant node of the given static shape (4 bytes per element).
fn zero_f32(g: &mut Graph, shape: &[usize]) -> NodeId {
    let byte_len = shape.iter().product::<usize>() * std::mem::size_of::<f32>();
    let zeros = vec![0u8; byte_len];
    let out_shape = Shape::new(shape, DType::F32);
    g.add_node(rlx_ir::op::Op::Constant { data: zeros }, Vec::new(), out_shape)
}
/// Adds an all-zero I32 constant node of the given static shape (4 bytes per element).
fn i32_zero_tensor(g: &mut Graph, shape: &[usize]) -> NodeId {
    let byte_len = shape.iter().product::<usize>() * std::mem::size_of::<i32>();
    let zeros = vec![0u8; byte_len];
    let out_shape = Shape::new(shape, DType::I32);
    g.add_node(rlx_ir::op::Op::Constant { data: zeros }, Vec::new(), out_shape)
}
/// Adds a one-element I32 constant node holding `v` (little-endian bytes).
fn scalar_i32(g: &mut Graph, v: i32) -> NodeId {
    let shape = Shape::new(&[1], DType::I32);
    let data = Vec::from(v.to_le_bytes());
    g.add_node(rlx_ir::op::Op::Constant { data }, vec![], shape)
}
/// Builds the two-sided type-embedding table used by strip mode.
///
/// `table` has `rows` rows of width `tebd`. The result is `[rows * rows, 2 * tebd]`. Row
/// `i * rows + j` (centre type `i`, neighbor type `j`) holds `[table[j], table[i]]`:
/// neighbor embedding first, centre embedding second. Callers index it with
/// `centre * rows + neighbor`.
///
/// The neighbor-first order matches DeePMD-kit's `two_side_type_embedding`
/// (`cat([type_embedding_nei, type_embedding_center])`) and the concat-mode input order
/// `[s, tebd(neighbor), tebd(centre)]`. The previous `[table[i], table[j]]` order fed the
/// halves to the strip MLP's first layer swapped.
///
/// # Errors
/// Currently infallible; the `Result` is kept for interface stability.
fn build_pair_table(
    g: &mut Graph,
    table: NodeId,
    rows: usize,
    tebd: usize,
) -> Result<NodeId> {
    let n = rows * rows;
    let (centre_idx, nei_idx): (Vec<i32>, Vec<i32>) = (0..rows)
        .flat_map(|i| (0..rows).map(move |j| (i as i32, j as i32)))
        .unzip();
    let to_bytes = |v: &[i32]| -> Vec<u8> { v.iter().flat_map(|x| x.to_le_bytes()).collect() };

    let centre_node = g.add_node(
        rlx_ir::op::Op::Constant { data: to_bytes(&centre_idx) },
        vec![],
        Shape::new(&[n], DType::I32),
    );
    let nei_node = g.add_node(
        rlx_ir::op::Op::Constant { data: to_bytes(&nei_idx) },
        vec![],
        Shape::new(&[n], DType::I32),
    );
    let t_centre = g.gather_(table, centre_node, 0);
    let t_nei = g.gather_(table, nei_node, 0);
    Ok(g.concat(
        vec![t_nei, t_centre],
        1,
        Shape::new(&[n, 2 * tebd], DType::F32),
    ))
}
// Never called. It references the `crate::nn` items this file imports but does not
// otherwise use (`apply_activation`, `dense_layer`, `DenseLayerSpec`), so they do not
// trigger unused-import warnings.
#[allow(dead_code)]
fn _keep_nn_imports(
    _a: ActivationKind,
    _b: DenseLayerSpec,
    _c: fn(_: &mut Graph, _: &DenseLayerSpec, _: NodeId) -> NodeId,
) {
    let _ = (apply_activation, dense_layer);
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Two atom types, two attention heads, concat-mode type embedding, smooth attention.
    fn small_concat_config() -> DPA1Config {
        DPA1Config {
            rcut: 6.0,
            rcut_smth: 0.5,
            sel: vec![20, 30],
            ntypes: 2,
            neuron: vec![8, 16, 32],
            axis_neuron: 4,
            tebd_dim: 8,
            tebd_input_mode: TebdInputMode::Concat,
            resnet_dt: false,
            attn: 16,
            attn_layer: 1,
            attn_dotr: true,
            num_heads: 2,
            normalize: true,
            scaling_factor: 1.0,
            ln_eps: 1e-5,
            smooth: true,
            type_one_side: false,
            activation_function: "tanh".into(),
            concat_output_tebd: true,
        }
    }

    #[test]
    fn dpa1_concat_builds() {
        let cfg = small_concat_config();
        let mut graph = Graph::new("dpa1_concat");
        let (nf, nloc) = (1, 4);
        let nnei = cfg.nnei();

        let env_mat = graph.input("env_mat", Shape::new(&[nf, nloc, nnei, 4], DType::F32));
        let atype = graph.input("atype", Shape::new(&[nf, nloc], DType::I32));
        let nei_atype = graph.input("nei_atype", Shape::new(&[nf, nloc, nnei], DType::I32));
        let nlist_mask = graph.input("nm", Shape::new(&[nf, nloc, nnei], DType::F32));
        let sw = graph.input("sw", Shape::new(&[nf, nloc, nnei], DType::F32));
        // One row more than `ntypes`, presumably for a padding type.
        let type_embedding = graph.param(
            "type_embed.table",
            Shape::new(&[cfg.ntypes + 1, cfg.tebd_dim], DType::F32),
        );

        let built = build_dpa1_descriptor(
            &mut graph,
            &cfg,
            DPA1Inputs {
                env_mat_raw: env_mat,
                atype_loc: atype,
                nei_atype,
                nlist_mask,
                sw,
                type_embedding,
            },
            nf,
            nloc,
        )
        .expect("build");

        assert_eq!(built.dim_out, cfg.dim_out());
        assert!(graph.len() > 50);
    }
}