use anyhow::{bail, Result};
use rlx_ir::infer::GraphExt;
use rlx_ir::op::{Activation, BinaryOp, ReduceOp};
use rlx_ir::{DType, Graph, NodeId, Shape};
use serde::{Deserialize, Serialize};
use crate::nn::{scalar_const, ActivationKind};
/// Hyper-parameters of the repformers descriptor block.
///
/// Deserialized with serde; every field that carries a `#[serde(default…)]`
/// attribute may be omitted from the serialized form.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct RepformersConfig {
/// Not read by the graph builders in this file.
/// NOTE(review): presumably the neighbor cutoff radius — confirm against the caller.
pub rcut: f64,
/// Not read by the graph builders in this file.
/// NOTE(review): presumably the radius where the switch function starts — confirm.
pub rcut_smth: f64,
/// Number of neighbor slots per local atom; used as `nnei` by `build_repformers`.
pub sel: usize,
/// Number of atom types; leading dimension of the `davg` / `dstd` parameters.
pub ntypes: usize,
/// Number of repformer layers stacked by `build_repformers` (default 3).
#[serde(default = "default_nlayers")]
pub nlayers: usize,
/// Feature width of the per-atom `g1` channel (default 128).
#[serde(default = "default_g1_dim")]
pub g1_dim: usize,
/// Feature width of the per-pair `g2` channel (default 16).
#[serde(default = "default_g2_dim")]
pub g2_dim: usize,
/// Number of leading columns kept by `symmetrization_op` (default 4).
#[serde(default = "default_axis_neuron")]
pub axis_neuron: usize,
/// Enables the `update_g1_conv` term of the `g1` update (default true).
#[serde(default = "default_true")]
pub update_g1_has_conv: bool,
/// Enables the symmetrization term built from neighbor `g1` features (default true).
#[serde(default = "default_true")]
pub update_g1_has_drrd: bool,
/// Enables the symmetrization term built from `g2` (default true).
#[serde(default = "default_true")]
pub update_g1_has_grrg: bool,
/// Enables the local attention term of the `g1` update (default true).
#[serde(default = "default_true")]
pub update_g1_has_attn: bool,
/// Enables the `g1 ⊗ neighbor g1` term of the `g2` update (default true).
#[serde(default = "default_true")]
pub update_g2_has_g1g1: bool,
/// Enables the gated-attention term of the `g2` update (default true).
#[serde(default = "default_true")]
pub update_g2_has_attn: bool,
/// Enables the `h2` update; defaults to `false` when omitted.
#[serde(default)]
pub update_h2: bool,
/// Per-head hidden width of the `g1` local attention (default 64).
#[serde(default = "default_attn1_hidden")]
pub attn1_hidden: usize,
/// Number of heads of the `g1` local attention (default 4).
#[serde(default = "default_attn1_nhead")]
pub attn1_nhead: usize,
/// Per-head hidden width of the `g2` attention map (default 16).
#[serde(default = "default_attn2_hidden")]
pub attn2_hidden: usize,
/// Number of heads of the `g2` attention map (default 4).
#[serde(default = "default_attn2_nhead")]
pub attn2_nhead: usize,
/// When set, the `g2` attention logits are multiplied by an `h2·h2ᵀ` gate;
/// defaults to `false` when omitted.
#[serde(default)]
pub attn2_has_gate: bool,
/// Activation name handed to `ActivationKind::parse` (default "tanh").
#[serde(default = "default_activation")]
pub activation_function: String,
/// Residual combination style; `build_repformers` rejects anything but
/// "res_avg" (the default).
#[serde(default = "default_update_style")]
pub update_style: String,
/// Selects the switch-weighted code paths over the plain masked ones in the
/// builders that consult it (default true).
#[serde(default = "default_true")]
pub smooth: bool,
/// `build_cal_hg` scales by `1/sqrt(nnei)` when true, `1/nnei` otherwise (default true).
#[serde(default = "default_true")]
pub use_sqrt_nnei: bool,
/// When true the conv term joins the residual list; when false it is fed
/// into the `linear1` concat instead (default true).
#[serde(default = "default_true")]
pub g1_out_conv: bool,
/// When true a separate `g1_self_mlp` term joins the residual list; when
/// false `g1` itself is fed into the `linear1` concat (default true).
#[serde(default = "default_true")]
pub g1_out_mlp: bool,
/// Epsilon passed to the layer norm applied after the `g2` attention (default 1e-5).
#[serde(default = "default_ln_eps")]
pub ln_eps: f32,
/// Not read by the graph builders in this file (default 1e-4).
#[serde(default = "default_epsilon")]
pub epsilon: f32,
}
/// Serde fallback for `RepformersConfig::nlayers`.
fn default_nlayers() -> usize {
    3_usize
}

/// Serde fallback for `RepformersConfig::g1_dim`.
fn default_g1_dim() -> usize {
    128_usize
}

/// Serde fallback for `RepformersConfig::g2_dim`.
fn default_g2_dim() -> usize {
    16_usize
}

/// Serde fallback for `RepformersConfig::axis_neuron`.
fn default_axis_neuron() -> usize {
    4_usize
}

/// Serde fallback for `RepformersConfig::attn1_hidden`.
fn default_attn1_hidden() -> usize {
    64_usize
}

/// Serde fallback for `RepformersConfig::attn1_nhead`.
fn default_attn1_nhead() -> usize {
    4_usize
}

/// Serde fallback for `RepformersConfig::attn2_hidden`.
fn default_attn2_hidden() -> usize {
    16_usize
}

/// Serde fallback for `RepformersConfig::attn2_nhead`.
fn default_attn2_nhead() -> usize {
    4_usize
}

/// Serde fallback for `RepformersConfig::activation_function`.
fn default_activation() -> String {
    String::from("tanh")
}

/// Serde fallback for `RepformersConfig::update_style`.
fn default_update_style() -> String {
    String::from("res_avg")
}

/// Serde fallback shared by every boolean switch that is on by default.
fn default_true() -> bool {
    true
}

/// Serde fallback for `RepformersConfig::ln_eps`.
fn default_ln_eps() -> f32 {
    1.0e-5_f32
}

/// Serde fallback for `RepformersConfig::epsilon`.
fn default_epsilon() -> f32 {
    1.0e-4_f32
}
impl RepformersConfig {
    /// Feature width of the `g1` channel (same value as `g1_dim`).
    pub fn ng1(&self) -> usize {
        let Self { g1_dim, .. } = self;
        *g1_dim
    }

    /// Feature width of the `g2` channel (same value as `g2_dim`).
    pub fn ng2(&self) -> usize {
        let Self { g2_dim, .. } = self;
        *g2_dim
    }
}
/// Graph nodes consumed by `build_repformers`.
///
/// Shapes below are the ones the builders reshape to or that the unit test
/// in this file feeds in; they are not checked here.
pub struct RepformersInputs {
/// Type embedding of the extended atoms; the unit test uses `[nf, nall, g1_dim]`.
/// The first `nloc` rows along axis 1 become the initial `g1`.
pub g1_ext: NodeId,
/// Raw environment matrix; the unit test uses `[nf, nloc, nnei, 4]`. Column 0
/// of the last axis feeds `g2`, columns 1..4 become `h2`.
pub env_mat_raw: NodeId,
/// Neighbor indices; reshaped to an I32 `[nf, nloc * nnei]` gather index.
pub nlist: NodeId,
/// Neighbor validity mask; reshaped to F32 `[nf, nloc, nnei, 1]` and multiplied in.
/// NOTE(review): presumably 1.0 for real neighbors and 0.0 for padding — confirm.
pub nlist_mask: NodeId,
/// Switch-function weights; reshaped to F32 `[nf, nloc, nnei, 1]`.
pub sw: NodeId,
/// Types of the local atoms; used as the gather index into `davg` / `dstd` along axis 0.
pub atype_loc: NodeId,
/// Index used to gather `g1` along axis 1 between layers. Required when
/// `nall != nloc` and more than one layer is built.
pub mapping: Option<NodeId>,
}
/// Graph nodes produced by `build_repformers`.
///
/// * `g1`, `g2`, `h2` — the three channels after the last layer.
/// * `rot_mat` — the `build_cal_hg(g2, h2)` result with its last two axes swapped.
/// * `sw` — the switch-weight input, passed through unchanged.
pub struct RepformersOutputs {
pub g1: NodeId, pub g2: NodeId, pub h2: NodeId, pub rot_mat: NodeId, pub sw: NodeId,
}
pub fn build_repformers(
g: &mut Graph,
cfg: &RepformersConfig,
inputs: RepformersInputs,
nf: usize,
nloc: usize,
nall: usize,
) -> Result<RepformersOutputs> {
if cfg.update_style != "res_avg" {
bail!(
"repformers: only update_style=\"res_avg\" is implemented (got \"{}\")",
cfg.update_style
);
}
let activation = ActivationKind::parse(&cfg.activation_function)?;
let _ng1 = cfg.ng1();
let ng2 = cfg.ng2();
let nnei = cfg.sel;
let davg = g.param(
"repformers.davg",
Shape::new(&[cfg.ntypes, nnei, 4], DType::F32),
);
let dstd = g.param(
"repformers.dstd",
Shape::new(&[cfg.ntypes, nnei, 4], DType::F32),
);
let davg_g = g.gather_(davg, inputs.atype_loc, 0);
let dstd_g = g.gather_(dstd, inputs.atype_loc, 0);
let mut dmatrix = g.sub(inputs.env_mat_raw, davg_g);
dmatrix = g.div(dmatrix, dstd_g);
let radial = g.narrow_(dmatrix, 3, 0, 1); let g2_w = g.param(
"repformers.g2_embd.w",
Shape::new(&[1, ng2], DType::F32),
);
let g2_b = g.param("repformers.g2_embd.b", Shape::new(&[ng2], DType::F32));
let g2_pre = g.mm(radial, g2_w);
let g2_pre = g.add(g2_pre, g2_b);
let mut g2 = crate::nn::apply_activation(g, activation, g2_pre);
let mut h2 = g.narrow_(dmatrix, 3, 1, 3);
let g1_ext_act = crate::nn::apply_activation(g, activation, inputs.g1_ext);
let mut g1 = g.narrow_(g1_ext_act, 1, 0, nloc);
let mask_4d_shape = Shape::new(&[nf, nloc, nnei, 1], DType::F32);
let mask_4d = g.reshape(
inputs.nlist_mask,
vec![nf as i64, nloc as i64, nnei as i64, 1],
mask_4d_shape,
);
g2 = g.mul(g2, mask_4d);
let mut g1_ext_cur = g1_ext_act;
for l in 0..cfg.nlayers {
let is_last = l == cfg.nlayers - 1;
let prefix = format!("repformers.layer{l}");
let (g1_new, g2_new, h2_new) = build_repformer_layer(
g, cfg, &prefix, activation, g1, g2, h2, inputs.nlist, inputs.nlist_mask,
inputs.sw, g1_ext_cur, nf, nloc, nnei, !is_last,
)?;
g1 = g1_new;
g2 = g2_new;
h2 = h2_new;
if l + 1 < cfg.nlayers {
g1_ext_cur = if nall == nloc {
g1
} else {
let mapping = inputs.mapping.ok_or_else(|| {
anyhow::anyhow!("repformers: mapping is required when nall != nloc")
})?;
g.gather_(g1, mapping, 1)
};
}
}
let hg = build_cal_hg(g, g2, h2, inputs.sw, cfg, nf, nloc, nnei)?; let rot_mat = g.transpose_(hg, vec![0, 1, 3, 2]);
Ok(RepformersOutputs {
g1,
g2,
h2,
rot_mat,
sw: inputs.sw,
})
}
fn build_repformer_layer(
g: &mut Graph,
cfg: &RepformersConfig,
prefix: &str,
activation: ActivationKind,
g1: NodeId,
g2: NodeId,
h2: NodeId,
nlist: NodeId,
nlist_mask: NodeId,
sw: NodeId,
g1_ext: NodeId,
nf: usize,
nloc: usize,
nnei: usize,
update_chnnl_2: bool,
) -> Result<(NodeId, NodeId, NodeId)> {
let ng1 = cfg.g1_dim;
let ng2 = cfg.g2_dim;
let axis = cfg.axis_neuron;
let gg1 = make_nei_g1(g, g1_ext, nlist, nf, nloc, nnei, ng1);
let do_chnnl_2 = update_chnnl_2;
let do_g2_g1g1 = do_chnnl_2 && cfg.update_g2_has_g1g1;
let do_g2_attn = do_chnnl_2 && cfg.update_g2_has_attn;
let do_h2 = do_chnnl_2 && cfg.update_h2;
let (g2_new, _aag_dbg) = if do_chnnl_2 {
let g2_1 = linear_act(g, &format!("{prefix}.linear2"), g2, ng2, ng2, activation);
let g2_g1g1 = if do_g2_g1g1 {
let g1g1 = update_g2_g1g1(g, g1, gg1, nlist_mask, sw, nf, nloc, nnei, ng1);
Some(linear(g, &format!("{prefix}.proj_g1g1g2"), g1g1, ng1, ng2))
} else {
None
};
let aag = if do_g2_attn || do_h2 {
Some(build_atten2_map(
g, cfg, &format!("{prefix}.attn2g_map"), g2, h2, nlist_mask, sw,
nf, nloc, nnei,
)?)
} else {
None
};
let g2_attn = if do_g2_attn {
let g2_mh = build_atten2_mh_apply(
g, cfg, &format!("{prefix}.attn2_mh_apply"), aag.unwrap(), g2,
nf, nloc, nnei, ng2,
);
let gamma = g.param(
format!("{prefix}.attn2_lm.w"),
Shape::new(&[ng2], DType::F32),
);
let beta = g.param(
format!("{prefix}.attn2_lm.b"),
Shape::new(&[ng2], DType::F32),
);
Some(g.ln(g2_mh, gamma, beta, cfg.ln_eps))
} else {
None
};
let mut g2_terms = vec![g2, g2_1];
if let Some(t) = g2_g1g1 {
g2_terms.push(t);
}
if let Some(t) = g2_attn {
g2_terms.push(t);
}
(list_update_res_avg(g, &g2_terms), aag)
} else {
(g2, None)
};
let h2_new = if do_h2 {
let aag = _aag_dbg.expect("aag must be set if update_h2");
let h2_upd = update_h2(g, &format!("{prefix}.attn2_ev_apply"), h2, aag, nf, nloc, nnei);
list_update_res_avg(g, &[h2, h2_upd])
} else {
h2
};
let mut g1_mlp: Vec<NodeId> = Vec::new();
if !cfg.g1_out_mlp {
g1_mlp.push(g1);
}
let g1_conv = if cfg.update_g1_has_conv {
Some(update_g1_conv(
g, &format!("{prefix}"), cfg, gg1, g2, nlist_mask, sw, nf, nloc, nnei, ng1, ng2,
))
} else {
None
};
if !cfg.g1_out_conv {
if let Some(c) = g1_conv {
g1_mlp.push(c);
}
}
if cfg.update_g1_has_grrg {
let grrg =
symmetrization_op(g, g2, h2, nlist_mask, sw, cfg, axis, nf, nloc, nnei, ng2);
g1_mlp.push(grrg);
}
if cfg.update_g1_has_drrd {
let drrd =
symmetrization_op(g, gg1, h2, nlist_mask, sw, cfg, axis, nf, nloc, nnei, ng1);
g1_mlp.push(drrd);
}
let concat_dim: usize = g1_mlp
.iter()
.map(|n| match g.shape(*n).dim(2) {
rlx_ir::Dim::Static(d) => d,
_ => 0,
})
.sum();
let mlp_in_shape = Shape::new(&[nf, nloc, concat_dim], DType::F32);
let g1_mlp_concat = g.concat(g1_mlp, 2, mlp_in_shape);
let g1_1 = linear_act(
g,
&format!("{prefix}.linear1"),
g1_mlp_concat,
concat_dim,
ng1,
activation,
);
let mut g1_updates = vec![g1];
if cfg.g1_out_mlp {
let g1_self = linear_act(
g,
&format!("{prefix}.g1_self_mlp"),
g1,
ng1,
ng1,
activation,
);
g1_updates.push(g1_self);
}
if cfg.g1_out_conv {
if let Some(c) = g1_conv {
g1_updates.push(c);
}
}
g1_updates.push(g1_1);
if cfg.update_g1_has_attn {
let loc_attn = build_local_atten(
g, cfg, &format!("{prefix}.loc_attn"), g1, gg1, nlist_mask, sw,
nf, nloc, nnei, ng1,
)?;
g1_updates.push(loc_attn);
}
let g1_new = list_update_res_avg(g, &g1_updates);
Ok((g1_new, g2_new, h2_new))
}
/// Emits an affine map `x · W + b`.
///
/// Registers `{prefix}.w` with shape `[in_dim, out_dim]` and `{prefix}.b`
/// with shape `[out_dim]`, in that order.
fn linear(
    g: &mut Graph,
    prefix: &str,
    x: NodeId,
    in_dim: usize,
    out_dim: usize,
) -> NodeId {
    let weight_shape = Shape::new(&[in_dim, out_dim], DType::F32);
    let weight = g.param(format!("{prefix}.w"), weight_shape);

    let bias_shape = Shape::new(&[out_dim], DType::F32);
    let bias = g.param(format!("{prefix}.b"), bias_shape);

    let projected = g.mm(x, weight);
    g.add(projected, bias)
}
/// Emits `linear` followed by the given activation.
fn linear_act(
    g: &mut Graph,
    prefix: &str,
    x: NodeId,
    in_dim: usize,
    out_dim: usize,
    activation: ActivationKind,
) -> NodeId {
    let affine = linear(g, prefix, x, in_dim, out_dim);
    let activated = crate::nn::apply_activation(g, activation, affine);
    activated
}
/// Gathers the `g1` rows of every neighbor listed in `nlist`.
///
/// `nlist` is flattened to an I32 `[nf, nloc * nnei]` index, used to gather
/// `g1_ext` along axis 1, and the result is reshaped to `[nf, nloc, nnei, ng1]`.
fn make_nei_g1(
    g: &mut Graph,
    g1_ext: NodeId,
    nlist: NodeId,
    nf: usize,
    nloc: usize,
    nnei: usize,
    ng1: usize,
) -> NodeId {
    let pairs = nloc * nnei;
    let flat_index = g.reshape(
        nlist,
        vec![nf as i64, pairs as i64],
        Shape::new(&[nf, pairs], DType::I32),
    );
    let rows = g.gather_(g1_ext, flat_index, 1);

    let dims = [nf, nloc, nnei, ng1];
    let target: Vec<i64> = dims.iter().map(|&d| d as i64).collect();
    g.reshape(rows, target, Shape::new(&dims, DType::F32))
}
/// Emits the `g1 ⊗ neighbor-g1` term of the `g2` update.
///
/// `g1` is reshaped to `[nf, nloc, 1, ng1]`, multiplied element-wise with
/// `gg1`, then weighted by the neighbor mask and by the switch weights (both
/// reshaped to `[nf, nloc, nnei, 1]`).
///
/// NOTE(review): the switch weights are applied unconditionally here, whereas
/// `update_g1_conv` and `build_cal_hg` only apply them when `cfg.smooth` is
/// set — confirm this asymmetry is intended.
fn update_g2_g1g1(
    g: &mut Graph,
    g1: NodeId,
    gg1: NodeId,
    nlist_mask: NodeId,
    sw: NodeId,
    nf: usize,
    nloc: usize,
    nnei: usize,
    ng1: usize,
) -> NodeId {
    let as_i64 = |dims: &[usize]| -> Vec<i64> { dims.iter().map(|&d| d as i64).collect() };
    let center_dims = [nf, nloc, 1, ng1];
    let column_dims = [nf, nloc, nnei, 1];

    let center = g.reshape(
        g1,
        as_i64(&center_dims),
        Shape::new(&center_dims, DType::F32),
    );
    let pairwise = g.mul(center, gg1);

    let mask_col = g.reshape(
        nlist_mask,
        as_i64(&column_dims),
        Shape::new(&column_dims, DType::F32),
    );
    let masked = g.mul(pairwise, mask_col);

    let sw_col = g.reshape(
        sw,
        as_i64(&column_dims),
        Shape::new(&column_dims, DType::F32),
    );
    g.mul(masked, sw_col)
}
/// Emits the convolution-style term of the `g1` update.
///
/// Neighbor `g1` features are masked (and switch-weighted when `cfg.smooth`),
/// multiplied with `g2` projected through `{prefix}.proj_g1g2.w`
/// (`[ng2, ng1]`, no bias), summed over the neighbor axis and scaled by
/// `1 / nnei`. The result has shape `[nf, nloc, ng1]`.
///
/// NOTE(review): the projection is always `ng2 -> ng1`, whatever
/// `cfg.g1_out_conv` says — confirm this matches the reference model.
fn update_g1_conv(
    g: &mut Graph,
    prefix: &str,
    cfg: &RepformersConfig,
    gg1: NodeId,
    g2: NodeId,
    nlist_mask: NodeId,
    sw: NodeId,
    nf: usize,
    nloc: usize,
    nnei: usize,
    ng1: usize,
    ng2: usize,
) -> NodeId {
    let column_dims = [nf, nloc, nnei, 1];
    let column_i64 = || -> Vec<i64> { column_dims.iter().map(|&d| d as i64).collect() };

    // Both reshapes are emitted even when `smooth` is off, as before.
    let mask_col = g.reshape(
        nlist_mask,
        column_i64(),
        Shape::new(&column_dims, DType::F32),
    );
    let sw_col = g.reshape(sw, column_i64(), Shape::new(&column_dims, DType::F32));

    let mut neighbors = g.mul(gg1, mask_col);
    if cfg.smooth {
        neighbors = g.mul(neighbors, sw_col);
    }

    let projection = g.param(
        format!("{prefix}.proj_g1g2.w"),
        Shape::new(&[ng2, ng1], DType::F32),
    );
    let g2_projected = g.mm(g2, projection);
    let weighted = g.mul(g2_projected, neighbors);

    let pooled = g.reduce(
        weighted,
        ReduceOp::Sum,
        vec![2],
        false,
        Shape::new(&[nf, nloc, ng1], DType::F32),
    );
    let inv_nnei = scalar_const(g, 1.0 / nnei as f32);
    g.mul(pooled, inv_nnei)
}
/// Emits the attention-driven update of `h2`.
///
/// `aag` is expected with the head axis last (axis 4); it is permuted to put
/// heads on axis 2, applied to `h2` replicated once per head, and the heads
/// are merged back through `{prefix}.head_map.w` (`[nh, 1]`, no bias). The
/// result is reshaped to `[nf, nloc, nnei, 3]`.
///
/// # Panics
/// Panics if axis 4 of `aag` is not a static dimension.
fn update_h2(
    g: &mut Graph,
    prefix: &str,
    h2: NodeId,
    aag: NodeId,
    nf: usize,
    nloc: usize,
    nnei: usize,
) -> NodeId {
    let heads = match g.shape(aag).dim(4) {
        rlx_ir::Dim::Static(count) => count,
        _ => panic!("aag head dim must be static"),
    };

    let attn_by_head = g.transpose_(aag, vec![0, 1, 4, 2, 3]);

    // Insert a singleton head axis, then replicate it `heads` times.
    let h2_single = g.reshape(
        h2,
        vec![nf as i64, nloc as i64, 1, nnei as i64, 3],
        Shape::new(&[nf, nloc, 1, nnei, 3], DType::F32),
    );
    let replicas = vec![h2_single; heads];
    let h2_per_head = g.concat(
        replicas,
        2,
        Shape::new(&[nf, nloc, heads, nnei, 3], DType::F32),
    );

    let applied = g.mm(attn_by_head, h2_per_head);
    let heads_last = g.transpose_(applied, vec![0, 1, 3, 4, 2]);

    let head_weights = g.param(
        format!("{prefix}.head_map.w"),
        Shape::new(&[heads, 1], DType::F32),
    );
    let merged = g.mm(heads_last, head_weights);

    g.reshape(
        merged,
        vec![nf as i64, nloc as i64, nnei as i64, 3],
        Shape::new(&[nf, nloc, nnei, 3], DType::F32),
    )
}
/// Emits `hᵀ · val`, scaled by the neighbor count.
///
/// `val` is first multiplied by the switch weights when `cfg.smooth` is set.
/// The product is scaled by `1 / sqrt(nnei)` when `cfg.use_sqrt_nnei` is set
/// and by `1 / nnei` otherwise.
///
/// Always returns `Ok`; the `Result` wrapper is kept for the callers.
fn build_cal_hg(
    g: &mut Graph,
    val: NodeId,
    h: NodeId,
    sw: NodeId,
    cfg: &RepformersConfig,
    nf: usize,
    nloc: usize,
    nnei: usize,
) -> Result<NodeId> {
    // The reshape is emitted even when `smooth` is off, as before.
    let column_dims = [nf, nloc, nnei, 1];
    let sw_col = g.reshape(
        sw,
        column_dims.iter().map(|&d| d as i64).collect(),
        Shape::new(&column_dims, DType::F32),
    );
    let weighted = match cfg.smooth {
        true => g.mul(val, sw_col),
        false => val,
    };

    let h_transposed = g.transpose_(h, vec![0, 1, 3, 2]);
    let product = g.mm(h_transposed, weighted);

    let factor = if cfg.use_sqrt_nnei {
        1.0 / (nnei as f32).sqrt()
    } else {
        1.0 / nnei as f32
    };
    let scale = scalar_const(g, factor);
    Ok(g.mul(product, scale))
}
/// Emits the symmetrization term fed into the `g1` MLP.
///
/// Computes `hg = build_cal_hg(val, h)`, keeps its first `axis` columns along
/// axis 3, multiplies the transpose of that slice with the full `hg`, scales
/// by `1/3` and flattens the result to `[nf, nloc, axis * ng]`.
///
/// `nlist_mask` is accepted for signature parity but not used.
fn symmetrization_op(
    g: &mut Graph,
    val: NodeId,
    h: NodeId,
    nlist_mask: NodeId,
    sw: NodeId,
    cfg: &RepformersConfig,
    axis: usize,
    nf: usize,
    nloc: usize,
    nnei: usize,
    ng: usize,
) -> NodeId {
    // Explicitly discard the unused mask.
    let _ = nlist_mask;

    let hg = build_cal_hg(g, val, h, sw, cfg, nf, nloc, nnei).unwrap();
    let leading = g.narrow_(hg, 3, 0, axis);
    let leading_t = g.transpose_(leading, vec![0, 1, 3, 2]);
    let gram = g.mm(leading_t, hg);

    let third = scalar_const(g, 1.0 / 3.0);
    let scaled = g.mul(gram, third);

    let flat_width = axis * ng;
    g.reshape(
        scaled,
        vec![nf as i64, nloc as i64, flat_width as i64],
        Shape::new(&[nf, nloc, flat_width], DType::F32),
    )
}
/// Sums every node in `list` (left to right) and scales the sum by
/// `1 / sqrt(list.len())`.
///
/// # Panics
/// Panics if `list` is empty.
fn list_update_res_avg(g: &mut Graph, list: &[NodeId]) -> NodeId {
    let total = list[1..]
        .iter()
        .fold(list[0], |running, &term| g.add(running, term));
    let scale = scalar_const(g, 1.0 / (list.len() as f32).sqrt());
    g.mul(total, scale)
}
/// Emits the multi-head neighbor-neighbor attention map used by the `g2`
/// and `h2` updates.
///
/// Queries and keys come from one bias-free projection of `g2`
/// (`{prefix}.mapqk.w`, `[ng2, 2 * nh * nd]`). The softmaxed weights are
/// masked, optionally switch-weighted, multiplied by `h2·h2ᵀ / sqrt(3)` and
/// returned with the head axis last: `[nf, nloc, nnei, nnei, nh]`.
///
/// NOTE(review): shape comments assume `Graph::mm` is a batched matmul over
/// the two trailing axes — confirm against `rlx_ir`.
///
/// Always returns `Ok` in the code visible here.
fn build_atten2_map(
g: &mut Graph,
cfg: &RepformersConfig,
prefix: &str,
g2: NodeId,
h2: NodeId,
nlist_mask: NodeId,
sw: NodeId,
nf: usize,
nloc: usize,
nnei: usize,
) -> Result<NodeId> {
let nh = cfg.attn2_nhead;
let nd = cfg.attn2_hidden;
let ng2 = cfg.g2_dim;
let w = g.param(
format!("{prefix}.mapqk.w"),
Shape::new(&[ng2, 2 * nh * nd], DType::F32),
);
// Project g2, then split the last axis into (nd, 2 * nh).
let qk = g.mm(g2, w); let qk_reshape_shape = Shape::new(&[nf, nloc, nnei, nd, 2 * nh], DType::F32);
let qk_r = g.reshape(
qk,
vec![nf as i64, nloc as i64, nnei as i64, nd as i64, (2 * nh) as i64],
qk_reshape_shape,
);
// Move the (2 * nh) axis to position 2; the first nh slices are queries, the next nh keys.
let qk_t = g.transpose_(qk_r, vec![0, 1, 4, 2, 3]); let q = g.narrow_(qk_t, 2, 0, nh); let k = g.narrow_(qk_t, 2, nh, nh);
// Scaled dot-product logits: q · kᵀ / sqrt(nd).
let k_t = g.transpose_(k, vec![0, 1, 2, 4, 3]); let mut attnw = g.mm(q, k_t); let scale = scalar_const(g, 1.0 / (nd as f32).sqrt());
attnw = g.mul(attnw, scale);
if cfg.attn2_has_gate {
// Optional gate: multiply the logits by h2 · h2ᵀ, broadcast over heads.
let h2_t = g.transpose_(h2, vec![0, 1, 3, 2]); let gate = g.mm(h2, h2_t); let gate_5d_shape = Shape::new(&[nf, nloc, 1, nnei, nnei], DType::F32);
let gate_5d = g.reshape(
gate,
vec![nf as i64, nloc as i64, 1, nnei as i64, nnei as i64],
gate_5d_shape,
);
attnw = g.mul(attnw, gate_5d);
}
if cfg.smooth {
// Smooth masking: logits * s + 20 * (s - 1), where s is the product of the
// switch weights of the query and key neighbors. Algebraically this equals
// (logits + 20) * s - 20.
let sw_b = g.reshape(
sw,
vec![nf as i64, nloc as i64, 1, 1, nnei as i64],
Shape::new(&[nf, nloc, 1, 1, nnei], DType::F32),
);
let sw_a = g.reshape(
sw,
vec![nf as i64, nloc as i64, 1, nnei as i64, 1],
Shape::new(&[nf, nloc, 1, nnei, 1], DType::F32),
);
let sw_prod = g.mul(sw_a, sw_b);
let aw_scaled = g.mul(attnw, sw_prod);
let one = scalar_const(g, 1.0);
let sw_prod_shape = g.shape(sw_prod).clone();
let sw_minus_one = g.binary(BinaryOp::Sub, sw_prod, one, sw_prod_shape);
let twenty = scalar_const(g, 20.0);
let bias_shape = g.shape(sw_minus_one).clone();
let bias = g.binary(BinaryOp::Mul, twenty, sw_minus_one, bias_shape);
attnw = g.add(aw_scaled, bias);
} else {
// Hard masking: add (1 - mask) * -1e30 along the key axis before the softmax.
let mask_5d = g.reshape(
nlist_mask,
vec![nf as i64, nloc as i64, 1, 1, nnei as i64],
Shape::new(&[nf, nloc, 1, 1, nnei], DType::F32),
);
let one = scalar_const(g, 1.0);
let inv = g.sub(one, mask_5d);
let neg_inf = scalar_const(g, -1e30);
let add = g.mul(inv, neg_inf);
attnw = g.add(attnw, add);
}
// Softmax over the last (key) axis.
let aw_shape = g.shape(attnw).clone();
attnw = g.softmax(attnw, -1, aw_shape);
// Multiply by the mask along the key axis, then along the query axis.
let mask_k_5d = g.reshape(
nlist_mask,
vec![nf as i64, nloc as i64, 1, 1, nnei as i64],
Shape::new(&[nf, nloc, 1, 1, nnei], DType::F32),
);
attnw = g.mul(attnw, mask_k_5d);
let mask_q_5d = g.reshape(
nlist_mask,
vec![nf as i64, nloc as i64, 1, nnei as i64, 1],
Shape::new(&[nf, nloc, 1, nnei, 1], DType::F32),
);
attnw = g.mul(attnw, mask_q_5d);
if cfg.smooth {
// Re-apply the switch weights to the normalised attention weights.
let sw_b = g.reshape(
sw,
vec![nf as i64, nloc as i64, 1, 1, nnei as i64],
Shape::new(&[nf, nloc, 1, 1, nnei], DType::F32),
);
let sw_a = g.reshape(
sw,
vec![nf as i64, nloc as i64, 1, nnei as i64, 1],
Shape::new(&[nf, nloc, 1, nnei, 1], DType::F32),
);
attnw = g.mul(attnw, sw_a);
attnw = g.mul(attnw, sw_b);
}
// Weight by h2 · h2ᵀ / sqrt(3), broadcast over heads.
let h2_t = g.transpose_(h2, vec![0, 1, 3, 2]);
let h2h2 = g.mm(h2, h2_t);
let inv_sqrt3 = scalar_const(g, 1.0 / 3f32.sqrt());
let h2h2 = g.mul(h2h2, inv_sqrt3);
let h2h2_5d_shape = Shape::new(&[nf, nloc, 1, nnei, nnei], DType::F32);
let h2h2_5d = g.reshape(
h2h2,
vec![nf as i64, nloc as i64, 1, nnei as i64, nnei as i64],
h2h2_5d_shape,
);
// Move the head axis last before returning.
let aag = g.mul(attnw, h2h2_5d); Ok(g.transpose_(aag, vec![0, 1, 3, 4, 2]))
}
/// Applies the multi-head attention map `aag` to `g2`.
///
/// Values come from a bias-free projection of `g2` (`{prefix}.mapv.w`,
/// `[ng2, nh * ng2]`). The per-head results are flattened to
/// `[nf, nloc, nnei, nh * ng2]` and merged by `{prefix}.head_map`
/// (`[nh * ng2, ng2]` weight plus `[ng2]` bias).
fn build_atten2_mh_apply(
    g: &mut Graph,
    cfg: &RepformersConfig,
    prefix: &str,
    aag: NodeId,
    g2: NodeId,
    nf: usize,
    nloc: usize,
    nnei: usize,
    ng2: usize,
) -> NodeId {
    let heads = cfg.attn2_nhead;
    let merged_width = heads * ng2;

    let value_w = g.param(
        format!("{prefix}.mapv.w"),
        Shape::new(&[ng2, merged_width], DType::F32),
    );
    let values = g.mm(g2, value_w);
    let values_split = g.reshape(
        values,
        vec![nf as i64, nloc as i64, nnei as i64, ng2 as i64, heads as i64],
        Shape::new(&[nf, nloc, nnei, ng2, heads], DType::F32),
    );

    // Bring the head axis to position 2 on both operands, multiply, then send it last again.
    let values_by_head = g.transpose_(values_split, vec![0, 1, 4, 2, 3]);
    let attn_by_head = g.transpose_(aag, vec![0, 1, 4, 2, 3]);
    let attended = g.mm(attn_by_head, values_by_head);
    let heads_last = g.transpose_(attended, vec![0, 1, 3, 4, 2]);

    let flattened = g.reshape(
        heads_last,
        vec![nf as i64, nloc as i64, nnei as i64, merged_width as i64],
        Shape::new(&[nf, nloc, nnei, merged_width], DType::F32),
    );

    let merge_w = g.param(
        format!("{prefix}.head_map.w"),
        Shape::new(&[merged_width, ng2], DType::F32),
    );
    let merge_b = g.param(
        format!("{prefix}.head_map.b"),
        Shape::new(&[ng2], DType::F32),
    );
    let merged = g.mm(flattened, merge_w);
    g.add(merged, merge_b)
}
/// Emits the multi-head local attention term of the `g1` update.
///
/// Each local atom (query, from `g1`) attends over its neighbors (keys and
/// values, from `gg1`). Queries use `{prefix}.mapq.w` (`[ng1, nd * nh]`);
/// keys and values share `{prefix}.mapkv.w` (`[ng1, (nd + ng1) * nh]`); both
/// are bias-free. The heads are merged by `{prefix}.head_map` into a
/// `[nf, nloc, ng1]` result.
///
/// NOTE(review): shape comments assume `Graph::mm` is a batched matmul over
/// the two trailing axes — confirm against `rlx_ir`.
///
/// Always returns `Ok` in the code visible here.
fn build_local_atten(
g: &mut Graph,
cfg: &RepformersConfig,
prefix: &str,
g1: NodeId, gg1: NodeId, nlist_mask: NodeId,
sw: NodeId,
nf: usize,
nloc: usize,
nnei: usize,
ng1: usize,
) -> Result<NodeId> {
let nh = cfg.attn1_nhead;
let nd = cfg.attn1_hidden;
let mapq_w = g.param(
format!("{prefix}.mapq.w"),
Shape::new(&[ng1, nd * nh], DType::F32),
);
// Queries: project g1, split into (nd, nh), then put heads before the hidden axis.
let g1q = g.mm(g1, mapq_w); let g1q_4d = g.reshape(
g1q,
vec![nf as i64, nloc as i64, nd as i64, nh as i64],
Shape::new(&[nf, nloc, nd, nh], DType::F32),
);
let g1q_perm = g.transpose_(g1q_4d, vec![0, 1, 3, 2]);
let mapkv_w = g.param(
format!("{prefix}.mapkv.w"),
Shape::new(&[ng1, (nd + ng1) * nh], DType::F32),
);
// Keys and values: one projection of the neighbor features, split into (nd + ng1, nh).
let gg1kv = g.mm(gg1, mapkv_w); let gg1kv_5d = g.reshape(
gg1kv,
vec![nf as i64, nloc as i64, nnei as i64, (nd + ng1) as i64, nh as i64],
Shape::new(&[nf, nloc, nnei, nd + ng1, nh], DType::F32),
);
let gg1kv_perm = g.transpose_(gg1kv_5d, vec![0, 1, 4, 2, 3]);
// The first nd channels of the last axis are keys, the remaining ng1 are values.
let gg1k = g.narrow_(gg1kv_perm, 4, 0, nd); let gg1v = g.narrow_(gg1kv_perm, 4, nd, ng1);
let g1q_5d = g.reshape(
g1q_perm,
vec![nf as i64, nloc as i64, nh as i64, 1, nd as i64],
Shape::new(&[nf, nloc, nh, 1, nd], DType::F32),
);
// Scaled dot-product logits: q · kᵀ / sqrt(nd).
let gg1k_t = g.transpose_(gg1k, vec![0, 1, 2, 4, 3]); let mut attnw = g.mm(g1q_5d, gg1k_t); let scale = scalar_const(g, 1.0 / (nd as f32).sqrt());
attnw = g.mul(attnw, scale);
// Drop the singleton query axis: [nf, nloc, nh, nnei].
let attnw_4d_shape = Shape::new(&[nf, nloc, nh, nnei], DType::F32);
let mut attnw = g.reshape(
attnw,
vec![nf as i64, nloc as i64, nh as i64, nnei as i64],
attnw_4d_shape,
);
if cfg.smooth {
// Smooth masking: logits * sw + 20 * (sw - 1), i.e. (logits + 20) * sw - 20.
let sw_4d = g.reshape(
sw,
vec![nf as i64, nloc as i64, 1, nnei as i64],
Shape::new(&[nf, nloc, 1, nnei], DType::F32),
);
let aw_scaled = g.mul(attnw, sw_4d);
let one = scalar_const(g, 1.0);
let sw_shape = g.shape(sw_4d).clone();
let sw_minus_one = g.binary(BinaryOp::Sub, sw_4d, one, sw_shape);
let twenty = scalar_const(g, 20.0);
let bias_shape = g.shape(sw_minus_one).clone();
let bias = g.binary(BinaryOp::Mul, twenty, sw_minus_one, bias_shape);
attnw = g.add(aw_scaled, bias);
} else {
// Hard masking: add (1 - mask) * -1e30 before the softmax.
let mask_4d = g.reshape(
nlist_mask,
vec![nf as i64, nloc as i64, 1, nnei as i64],
Shape::new(&[nf, nloc, 1, nnei], DType::F32),
);
let one = scalar_const(g, 1.0);
let inv = g.sub(one, mask_4d);
let neg_inf = scalar_const(g, -1e30);
let add = g.mul(inv, neg_inf);
attnw = g.add(attnw, add);
}
// Softmax over the neighbor axis, then multiply by the mask.
let aw_shape = g.shape(attnw).clone();
attnw = g.softmax(attnw, -1, aw_shape);
let mask_4d = g.reshape(
nlist_mask,
vec![nf as i64, nloc as i64, 1, nnei as i64],
Shape::new(&[nf, nloc, 1, nnei], DType::F32),
);
attnw = g.mul(attnw, mask_4d);
if cfg.smooth {
// Re-apply the switch weights to the normalised attention weights.
let sw_4d = g.reshape(
sw,
vec![nf as i64, nloc as i64, 1, nnei as i64],
Shape::new(&[nf, nloc, 1, nnei], DType::F32),
);
attnw = g.mul(attnw, sw_4d);
}
// Restore the singleton query axis so the weights can multiply the values.
let attnw_5d = g.reshape(
attnw,
vec![nf as i64, nloc as i64, nh as i64, 1, nnei as i64],
Shape::new(&[nf, nloc, nh, 1, nnei], DType::F32),
);
// Weighted sum of the values, flattened to [nf, nloc, nh * ng1].
let ret = g.mm(attnw_5d, gg1v); let ret_flat_shape = Shape::new(&[nf, nloc, nh * ng1], DType::F32);
let ret_flat = g.reshape(
ret,
vec![nf as i64, nloc as i64, (nh * ng1) as i64],
ret_flat_shape,
);
// Merge the heads back to ng1 channels.
let head_w = g.param(
format!("{prefix}.head_map.w"),
Shape::new(&[nh * ng1, ng1], DType::F32),
);
let head_b = g.param(
format!("{prefix}.head_map.b"),
Shape::new(&[ng1], DType::F32),
);
let mm = g.mm(ret_flat, head_w);
Ok(g.add(mm, head_b))
}
// No-op that only names `Activation` and `BinaryOp` in its signature; within
// this file it is the sole reference to the imported `Activation`.
// NOTE(review): presumably kept to silence an unused-import warning — confirm
// before removing either the import or this function.
#[allow(dead_code)]
fn _keep_act_path(_a: Activation, _b: BinaryOp) {}
#[cfg(test)]
mod tests {
    use super::*;

    /// Small single-layer configuration with every optional `g1` / `g2`
    /// update term enabled and the `h2` update disabled.
    fn single_layer_config() -> RepformersConfig {
        RepformersConfig {
            rcut: 6.0,
            rcut_smth: 0.5,
            sel: 16,
            ntypes: 2,
            nlayers: 1,
            g1_dim: 16,
            g2_dim: 8,
            axis_neuron: 2,
            update_g1_has_conv: true,
            update_g1_has_drrd: true,
            update_g1_has_grrg: true,
            update_g1_has_attn: true,
            update_g2_has_g1g1: true,
            update_g2_has_attn: true,
            update_h2: false,
            attn1_hidden: 8,
            attn1_nhead: 2,
            attn2_hidden: 4,
            attn2_nhead: 2,
            attn2_has_gate: false,
            activation_function: "tanh".into(),
            update_style: "res_avg".into(),
            smooth: true,
            use_sqrt_nnei: true,
            g1_out_conv: true,
            g1_out_mlp: true,
            ln_eps: 1e-5,
            epsilon: 1e-4,
        }
    }

    /// The block builds without error and `g1` keeps its feature width.
    #[test]
    fn repformers_builds_default() {
        let cfg = single_layer_config();
        let mut graph = Graph::new("repformers");

        let nf = 1;
        let nloc = 2;
        let nall = nloc;
        let nnei = cfg.sel;

        let g1_ext = graph.input("g1_ext", Shape::new(&[nf, nall, cfg.g1_dim], DType::F32));
        let env_mat_raw = graph.input("em", Shape::new(&[nf, nloc, nnei, 4], DType::F32));
        let nlist = graph.input("nl", Shape::new(&[nf, nloc, nnei], DType::I32));
        let nlist_mask = graph.input("nm", Shape::new(&[nf, nloc, nnei], DType::F32));
        let sw = graph.input("sw", Shape::new(&[nf, nloc, nnei], DType::F32));
        let atype_loc = graph.input("at", Shape::new(&[nf, nloc], DType::I32));

        let inputs = RepformersInputs {
            g1_ext,
            env_mat_raw,
            nlist,
            nlist_mask,
            sw,
            atype_loc,
            mapping: None,
        };
        let outputs =
            build_repformers(&mut graph, &cfg, inputs, nf, nloc, nall).expect("build");

        let g1_shape = graph.shape(outputs.g1);
        assert_eq!(g1_shape.dim(2), rlx_ir::Dim::Static(cfg.g1_dim));
    }
}