use anyhow::{bail, Result};
use rlx_ir::infer::GraphExt;
use rlx_ir::op::ReduceOp;
use rlx_ir::{DType, Graph, NodeId, Shape};
use serde::{Deserialize, Serialize};
use crate::nn::{scalar_const, ActivationKind};
/// Hyper-parameters consumed by [`build_repflows`]; (de)serializable with serde.
///
/// Fields carrying a `#[serde(default = "...")]` attribute fall back to the
/// matching `default_*` function in this file when they are missing from the
/// deserialized input.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct RepflowsConfig {
/// Not read by any code visible in this file. Presumably the edge cutoff
/// radius -- TODO confirm against the code that builds the environment matrix.
pub rcut: f64,
/// Not read by any code visible in this file. Presumably the radius at which
/// the edge switch function starts to decay -- TODO confirm.
pub rcut_smth: f64,
/// Neighbor slots per local atom; `build_repflows` uses it as the length of
/// the neighbor axis (`nnei`).
pub nsel: usize,
/// Not read by any code visible in this file. Presumably the angle cutoff
/// radius -- TODO confirm.
pub a_rcut: f64,
/// Not read by any code visible in this file. Presumably the radius at which
/// the angle switch function starts to decay -- TODO confirm.
pub a_rcut_smth: f64,
/// Number of leading neighbor slots that take part in the angle terms; it is
/// the length of both angle axes in the layer builder.
pub a_sel: usize,
/// Leading dimension of the `repflows.davg` / `repflows.dstd` parameters.
pub ntypes: usize,
/// Number of repflow layers stacked by `build_repflows` (default 6).
#[serde(default = "default_nlayers")]
pub nlayers: usize,
/// Width of the node embedding (default 128).
#[serde(default = "default_n_dim")]
pub n_dim: usize,
/// Width of the edge embedding (default 16).
#[serde(default = "default_e_dim")]
pub e_dim: usize,
/// Width of the angle embedding (default 8).
#[serde(default = "default_a_dim")]
pub a_dim: usize,
/// Activation name handed to `ActivationKind::parse` (default `"silu"`).
#[serde(default = "default_activation")]
pub activation_function: String,
/// Must be `true`: `build_repflows` returns an error otherwise (default `true`).
#[serde(default = "default_true")]
pub smooth: bool,
/// Not read by any code visible in this file (default `1e-5`).
#[serde(default = "default_ln_eps")]
pub ln_eps: f32,
/// Number of leading channels kept by the symmetrization op (default 4).
#[serde(default = "default_axis_neuron")]
pub axis_neuron: usize,
/// Not read by any code visible in this file; the layer builder updates the
/// angle embedding unconditionally (default `true`).
#[serde(default = "default_true")]
pub update_angle: bool,
}
/// Serde fallback for `RepflowsConfig::axis_neuron` when the field is omitted.
fn default_axis_neuron() -> usize {
    const AXIS_NEURON: usize = 4;
    AXIS_NEURON
}
impl RepflowsConfig {
    /// Returns the configured `axis_neuron` value, i.e. how many leading
    /// channels the symmetrization op keeps.
    pub fn axis_neuron(&self) -> usize {
        let Self { axis_neuron, .. } = self;
        *axis_neuron
    }
}
/// Serde fallback for `RepflowsConfig::nlayers` when the field is omitted.
fn default_nlayers() -> usize {
    const NLAYERS: usize = 6;
    NLAYERS
}
/// Serde fallback for `RepflowsConfig::n_dim` when the field is omitted.
fn default_n_dim() -> usize {
    const N_DIM: usize = 128;
    N_DIM
}
/// Serde fallback for `RepflowsConfig::e_dim` when the field is omitted.
fn default_e_dim() -> usize {
    const E_DIM: usize = 16;
    E_DIM
}
/// Serde fallback for `RepflowsConfig::a_dim` when the field is omitted.
fn default_a_dim() -> usize {
    const A_DIM: usize = 8;
    A_DIM
}
/// Serde fallback for `RepflowsConfig::activation_function` when the field is
/// omitted.
fn default_activation() -> String {
    String::from("silu")
}
/// Serde fallback for the boolean switches of `RepflowsConfig` that are
/// enabled unless stated otherwise (`smooth`, `update_angle`).
fn default_true() -> bool {
    const ENABLED: bool = true;
    ENABLED
}
/// Serde fallback for `RepflowsConfig::ln_eps` when the field is omitted.
fn default_ln_eps() -> f32 {
    const LN_EPS: f32 = 1.0e-5;
    LN_EPS
}
/// Graph nodes fed into [`build_repflows`].
///
/// Shapes quoted below are the ones supplied by this file's unit test or
/// imposed by reshapes inside the builder; other layouts are not checked here.
pub struct RepflowsInputs {
/// Extended node embedding (`[nf, nall, n_dim]` in the unit test). It is
/// passed through the activation; the first `nloc` entries along axis 1 become
/// the local node embedding, and neighbor embeddings are gathered from it
/// along axis 1.
pub node_ebd_ext: NodeId,
/// Raw environment matrix (`[nf, nloc, nnei, 4]` in the unit test). It is
/// normalised with the per-type `davg` / `dstd` parameters, then split along
/// axis 3 into one edge channel and three `h2` channels.
pub env_mat_raw: NodeId,
/// Neighbor indices; reshaped to `[nf, nloc * nnei]` (`I32`) and used as the
/// gather index into the extended node embedding.
pub nlist: NodeId,
/// Forwarded to the layer builder, which currently ignores it.
pub nlist_mask: NodeId,
/// Optional angle switch values, reshaped per layer to `[nf, nloc, a_sel, 1, 1]`
/// and `[nf, nloc, 1, a_sel, 1]`. When `None`, the first `a_sel` entries of
/// `sw` along axis 2 are used instead.
pub a_sw: Option<NodeId>,
/// Edge switch values; reshaped to `[nf, nloc, nnei, 1]` and multiplied into
/// the per-neighbor messages. Also returned unchanged in the outputs.
pub sw: NodeId,
/// Angle feature (`[nf, nloc, a_sel, a_sel, 1]` in the unit test); projected
/// by the `[1, a_dim]` angle-embedding weight.
pub angle_input: NodeId,
/// Angle mask; the layer builder reshapes it to `[nf, nloc, a_sel, a_sel]`
/// but does not use the result.
pub angle_mask: NodeId,
/// Per-atom type indices (`[nf, nloc]`, `I32` in the unit test) used to gather
/// rows of `davg` / `dstd` along axis 0.
pub atype_loc: NodeId,
}
/// Graph nodes produced by [`build_repflows`].
pub struct RepflowsOutputs {
/// Node embedding after the last repflow layer.
pub node_ebd: NodeId,
/// Edge embedding after the last repflow layer.
pub edge_ebd: NodeId,
/// Channels 1..4 (axis 3) of the normalised environment matrix; the layers
/// pass it through unchanged.
pub h2: NodeId,
/// `edge_ebd` with its last two axes swapped, matrix-multiplied with `h2`
/// and scaled by `1 / sqrt(nsel)`.
pub rot_mat: NodeId,
/// The `sw` input, passed through untouched.
pub sw: NodeId,
}
/// Emits the repflows descriptor sub-graph into `g` and returns its output nodes.
///
/// The builder
/// 1. normalises `env_mat_raw` as `(env - davg[atype]) / dstd[atype]` using the
///    `repflows.davg` / `repflows.dstd` parameters (`[ntypes, nsel, 4]`),
/// 2. activates the extended node embedding and takes its first `nloc` entries
///    along axis 1 as the local node embedding,
/// 3. embeds channel 0 of the normalised matrix into `e_dim` edge features
///    (linear + activation) and keeps channels 1..4 as `h2`,
/// 4. embeds `angle_input` into `a_dim` angle features (linear, no activation),
/// 5. stacks `cfg.nlayers` repflow layers (`repflows.layer{l}`),
/// 6. forms `rot_mat` from the final edge embedding and `h2`.
///
/// `nf`, `nloc` and `nall` are the static frame, local-atom and extended-atom
/// counts used for every reshape.
///
/// NOTE: the extended node embedding that neighbors are gathered from is only
/// refreshed between layers when `nall == nloc`. With ghost atoms
/// (`nall != nloc`) every layer gathers from the initial activated embedding.
///
/// # Errors
/// Returns an error when `cfg.smooth` is `false`, when `cfg.a_sel > cfg.nsel`,
/// when the activation name cannot be parsed, or when a layer fails to build.
pub fn build_repflows(
    g: &mut Graph,
    cfg: &RepflowsConfig,
    inputs: RepflowsInputs,
    nf: usize,
    nloc: usize,
    nall: usize,
) -> Result<RepflowsOutputs> {
    if !cfg.smooth {
        bail!("repflows: only smooth=true is implemented");
    }
    // The angle terms use the first `a_sel` of the `nsel` neighbor slots. A
    // larger `a_sel` would request an out-of-range narrow below and underflow
    // the `nnei - a_sel` padding size in the layer builder, so reject it
    // before touching the graph.
    if cfg.a_sel > cfg.nsel {
        bail!(
            "repflows: a_sel ({}) must not exceed nsel ({})",
            cfg.a_sel,
            cfg.nsel
        );
    }
    let activation = ActivationKind::parse(&cfg.activation_function)?;
    let n_dim = cfg.n_dim;
    let e_dim = cfg.e_dim;
    let a_dim = cfg.a_dim;
    let nnei = cfg.nsel;
    let a_sel = cfg.a_sel;

    // Per-type statistics used to normalise the environment matrix.
    let davg = g.param(
        "repflows.davg",
        Shape::new(&[cfg.ntypes, nnei, 4], DType::F32),
    );
    let dstd = g.param(
        "repflows.dstd",
        Shape::new(&[cfg.ntypes, nnei, 4], DType::F32),
    );
    let davg_g = g.gather_(davg, inputs.atype_loc, 0);
    let dstd_g = g.gather_(dstd, inputs.atype_loc, 0);
    let mut dmatrix = g.sub(inputs.env_mat_raw, davg_g);
    dmatrix = g.div(dmatrix, dstd_g);

    // Node embedding: activate the extended embedding, keep the local part.
    let node_ebd_ext_act = crate::nn::apply_activation(g, activation, inputs.node_ebd_ext);
    let mut node_ebd = g.narrow_(node_ebd_ext_act, 1, 0, nloc);
    let mut node_ebd_ext_cur = node_ebd_ext_act;

    // Edge embedding from channel 0 of the normalised environment matrix.
    let edge_input = g.narrow_(dmatrix, 3, 0, 1);
    let edge_w = g.param(
        "repflows.edge_embd.w",
        Shape::new(&[1, e_dim], DType::F32),
    );
    let edge_b = g.param(
        "repflows.edge_embd.b",
        Shape::new(&[e_dim], DType::F32),
    );
    let edge_pre = g.mm(edge_input, edge_w);
    let edge_pre = g.add(edge_pre, edge_b);
    let mut edge_ebd = crate::nn::apply_activation(g, activation, edge_pre);

    // Remaining three channels of the environment matrix.
    let mut h2 = g.narrow_(dmatrix, 3, 1, 3);

    // Angle embedding: linear projection only, no activation.
    let ang_w = g.param(
        "repflows.angle_embd.w",
        Shape::new(&[1, a_dim], DType::F32),
    );
    let ang_b = g.param(
        "repflows.angle_embd.b",
        Shape::new(&[a_dim], DType::F32),
    );
    let ang_pre = g.mm(inputs.angle_input, ang_w);
    let mut angle_ebd = g.add(ang_pre, ang_b);

    // Fall back to the leading `a_sel` edge switch values when no dedicated
    // angle switch is supplied.
    let a_sw_in = inputs
        .a_sw
        .unwrap_or_else(|| g.narrow_(inputs.sw, 2, 0, a_sel));

    for l in 0..cfg.nlayers {
        let prefix = format!("repflows.layer{l}");
        let (n, e, h, a) = build_repflow_layer(
            g,
            cfg,
            &prefix,
            activation,
            node_ebd,
            edge_ebd,
            h2,
            angle_ebd,
            inputs.nlist,
            inputs.nlist_mask,
            inputs.sw,
            inputs.angle_mask,
            a_sw_in,
            node_ebd_ext_cur,
            nf,
            nloc,
            nnei,
            a_sel,
            n_dim,
            e_dim,
            a_dim,
            nall,
        )?;
        node_ebd = n;
        edge_ebd = e;
        h2 = h;
        angle_ebd = a;
        // Without ghost atoms the updated local embedding *is* the extended
        // one, so the next layer can gather from it directly.
        if l + 1 < cfg.nlayers && nall == nloc {
            node_ebd_ext_cur = node_ebd;
        }
    }

    // rot_mat = edge_ebd^T (last two axes swapped) x h2, scaled by 1/sqrt(nnei).
    let edge_t = g.transpose_(edge_ebd, vec![0, 1, 3, 2]);
    let rot_pre = g.mm(edge_t, h2);
    let inv = scalar_const(g, 1.0 / (nnei as f32).sqrt());
    let rot_mat = g.mul(rot_pre, inv);

    Ok(RepflowsOutputs {
        node_ebd,
        edge_ebd,
        h2,
        rot_mat,
        sw: inputs.sw,
    })
}
fn build_repflow_layer(
g: &mut Graph,
cfg: &RepflowsConfig,
prefix: &str,
activation: ActivationKind,
node_ebd: NodeId,
edge_ebd: NodeId,
h2: NodeId,
angle_ebd: NodeId,
nlist: NodeId,
_nlist_mask: NodeId,
sw: NodeId,
angle_mask: NodeId,
a_sw: NodeId,
node_ebd_ext: NodeId,
nf: usize,
nloc: usize,
nnei: usize,
a_sel: usize,
n_dim: usize,
e_dim: usize,
a_dim: usize,
_nall: usize,
) -> Result<(NodeId, NodeId, NodeId, NodeId)> {
let axis = cfg.axis_neuron();
let nlist_2d_shape = Shape::new(&[nf, nloc * nnei], DType::I32);
let nlist_2d = g.reshape(
nlist,
vec![nf as i64, (nloc * nnei) as i64],
nlist_2d_shape,
);
let gathered = g.gather_(node_ebd_ext, nlist_2d, 1);
let gg_shape = Shape::new(&[nf, nloc, nnei, n_dim], DType::F32);
let nei_node_ebd = g.reshape(
gathered,
vec![nf as i64, nloc as i64, nnei as i64, n_dim as i64],
gg_shape,
);
let sw_4d = g.reshape(
sw,
vec![nf as i64, nloc as i64, nnei as i64, 1],
Shape::new(&[nf, nloc, nnei, 1], DType::F32),
);
let node_self = linear_act(
g, &format!("{prefix}.node_self_mlp"), node_ebd, n_dim, n_dim, activation,
);
let sym_edge = symmetrization_op(g, edge_ebd, h2, sw, cfg, axis, nf, nloc, nnei, e_dim);
let sym_nei_node = symmetrization_op(g, nei_node_ebd, h2, sw, cfg, axis, nf, nloc, nnei, n_dim);
let n_sym_dim = e_dim * axis + n_dim * axis;
let sym_cat_shape = Shape::new(&[nf, nloc, n_sym_dim], DType::F32);
let sym_cat = g.concat(vec![sym_edge, sym_nei_node], 2, sym_cat_shape);
let node_sym = linear_act(
g, &format!("{prefix}.node_sym_linear"), sym_cat, n_sym_dim, n_dim, activation,
);
let edge_info_dim = 2 * n_dim + e_dim;
let node_4d_shape = Shape::new(&[nf, nloc, 1, n_dim], DType::F32);
let node_4d = g.reshape(
node_ebd,
vec![nf as i64, nloc as i64, 1, n_dim as i64],
node_4d_shape,
);
let node_tile = broadcast_along_axis(g, node_4d, 2, nnei, &[nf, nloc, nnei, n_dim]);
let edge_info_shape =
Shape::new(&[nf, nloc, nnei, edge_info_dim], DType::F32);
let edge_info =
g.concat(vec![node_tile, nei_node_ebd, edge_ebd], 3, edge_info_shape);
let nem_pre = linear_act(
g, &format!("{prefix}.node_edge_linear"), edge_info, edge_info_dim, n_dim, activation,
);
let nem_sw = g.mul(nem_pre, sw_4d);
let n_pool_shape = Shape::new(&[nf, nloc, n_dim], DType::F32);
let node_edge_msg =
g.reduce(nem_sw, ReduceOp::Sum, vec![2], false, n_pool_shape);
let inv_nnei = scalar_const(g, 1.0 / nnei as f32);
let node_edge_msg = g.mul(node_edge_msg, inv_nnei);
let node_new = res_avg(g, &[node_ebd, node_self, node_sym, node_edge_msg]);
let edge_self = linear_act(
g, &format!("{prefix}.edge_self_linear"), edge_info, edge_info_dim, e_dim, activation,
);
let _ = angle_mask;
let edge_for_angle = g.narrow_(edge_ebd, 2, 0, a_sel); let a_mask_4d = g.reshape(
angle_mask,
vec![nf as i64, nloc as i64, a_sel as i64, a_sel as i64],
Shape::new(&[nf, nloc, a_sel, a_sel], DType::F32),
);
let _ = a_mask_4d;
let edge_for_angle_4d = g.reshape(
edge_for_angle,
vec![nf as i64, nloc as i64, 1, a_sel as i64, e_dim as i64],
Shape::new(&[nf, nloc, 1, a_sel, e_dim], DType::F32),
);
let edge_for_angle_k =
broadcast_along_axis(g, edge_for_angle_4d, 2, a_sel, &[nf, nloc, a_sel, a_sel, e_dim]);
let edge_for_angle_4d_j = g.reshape(
edge_for_angle,
vec![nf as i64, nloc as i64, a_sel as i64, 1, e_dim as i64],
Shape::new(&[nf, nloc, a_sel, 1, e_dim], DType::F32),
);
let edge_for_angle_j =
broadcast_along_axis(g, edge_for_angle_4d_j, 3, a_sel, &[nf, nloc, a_sel, a_sel, e_dim]);
let edge_for_angle_info_shape =
Shape::new(&[nf, nloc, a_sel, a_sel, 2 * e_dim], DType::F32);
let edge_for_angle_info = g.concat(
vec![edge_for_angle_k, edge_for_angle_j],
4,
edge_for_angle_info_shape,
);
let node_5d_shape = Shape::new(&[nf, nloc, 1, 1, n_dim], DType::F32);
let node_5d = g.reshape(
node_ebd,
vec![nf as i64, nloc as i64, 1, 1, n_dim as i64],
node_5d_shape,
);
let node_for_angle_info_inner =
broadcast_along_axis(g, node_5d, 2, a_sel, &[nf, nloc, a_sel, 1, n_dim]);
let node_for_angle_info = broadcast_along_axis(
g,
node_for_angle_info_inner,
3,
a_sel,
&[nf, nloc, a_sel, a_sel, n_dim],
);
let angle_info_dim = a_dim + n_dim + 2 * e_dim;
let angle_info_shape =
Shape::new(&[nf, nloc, a_sel, a_sel, angle_info_dim], DType::F32);
let angle_info = g.concat(
vec![angle_ebd, node_for_angle_info, edge_for_angle_info],
4,
angle_info_shape,
);
let edge_angle_update = linear_act(
g,
&format!("{prefix}.edge_angle_linear1"),
angle_info,
angle_info_dim,
a_dim,
activation,
);
let a_sw_5d_a = g.reshape(
a_sw,
vec![nf as i64, nloc as i64, a_sel as i64, 1, 1],
Shape::new(&[nf, nloc, a_sel, 1, 1], DType::F32),
);
let a_sw_5d_b = g.reshape(
a_sw,
vec![nf as i64, nloc as i64, 1, a_sel as i64, 1],
Shape::new(&[nf, nloc, 1, a_sel, 1], DType::F32),
);
let a_sw_prod = g.mul(a_sw_5d_a, a_sw_5d_b);
let weighted = g.mul(a_sw_prod, edge_angle_update);
let reduced_shape = Shape::new(&[nf, nloc, a_sel, a_dim], DType::F32);
let reduced = g.reduce(weighted, ReduceOp::Sum, vec![3], false, reduced_shape);
let inv_sqrt_a_sel = scalar_const(g, 1.0 / (a_sel as f32).sqrt());
let reduced = g.mul(reduced, inv_sqrt_a_sel);
let padded_shape = Shape::new(&[nf, nloc, nnei, a_dim], DType::F32);
let pad_zeros = zero_f32(g, &[nf, nloc, nnei - a_sel, a_dim]);
let padded = g.concat(vec![reduced, pad_zeros], 2, padded_shape);
let edge_angle_reduction = linear_act(
g,
&format!("{prefix}.edge_angle_linear2"),
padded,
a_dim,
e_dim,
activation,
);
let edge_new = res_avg(g, &[edge_ebd, edge_self, edge_angle_reduction]);
let angle_self = linear_act(
g,
&format!("{prefix}.angle_self_linear"),
angle_info,
angle_info_dim,
a_dim,
activation,
);
let angle_new = res_avg(g, &[angle_ebd, angle_self]);
Ok((node_new, edge_new, h2, angle_new))
}
/// Repeats `x` `n` times along `axis` by concatenating `n` references to the
/// same node; the result is declared as an `F32` tensor of shape `out_shape`.
fn broadcast_along_axis(
    g: &mut Graph,
    x: NodeId,
    axis: usize,
    n: usize,
    out_shape: &[usize],
) -> NodeId {
    let copies = vec![x; n];
    let target = Shape::new(out_shape, DType::F32);
    g.concat(copies, axis, target)
}
/// Builds the symmetrization term for `val` (last dimension `ng`).
///
/// `val` is weighted by `sw` (reshaped to `[nf, nloc, nnei, 1]`), multiplied
/// from the left by `h` with its last two axes swapped and scaled by
/// `1 / sqrt(nnei)`. The first `axis` channels of that product (axis 3) are
/// then transposed and multiplied with the full product, scaled by `1 / 3`,
/// and flattened to `[nf, nloc, axis * ng]`. `_cfg` is unused.
fn symmetrization_op(
    g: &mut Graph,
    val: NodeId,
    h: NodeId,
    sw: NodeId,
    _cfg: &RepflowsConfig,
    axis: usize,
    nf: usize,
    nloc: usize,
    nnei: usize,
    ng: usize,
) -> NodeId {
    // Switch-weighted input.
    let gate_dims = vec![nf as i64, nloc as i64, nnei as i64, 1];
    let gate_shape = Shape::new(&[nf, nloc, nnei, 1], DType::F32);
    let gate = g.reshape(sw, gate_dims, gate_shape);
    let gated = g.mul(val, gate);

    // hg = h^T x gated / sqrt(nnei)
    let h_swapped = g.transpose_(h, vec![0, 1, 3, 2]);
    let hg_raw = g.mm(h_swapped, gated);
    let nnei_scale = scalar_const(g, 1.0 / (nnei as f32).sqrt());
    let hg = g.mul(hg_raw, nnei_scale);

    // grrg = hg[..., :axis]^T x hg / 3
    let hg_head = g.narrow_(hg, 3, 0, axis);
    let hg_head_swapped = g.transpose_(hg_head, vec![0, 1, 3, 2]);
    let grrg_raw = g.mm(hg_head_swapped, hg);
    let one_third = scalar_const(g, 1.0 / 3.0);
    let grrg = g.mul(grrg_raw, one_third);

    let flat = axis * ng;
    g.reshape(
        grrg,
        vec![nf as i64, nloc as i64, flat as i64],
        Shape::new(&[nf, nloc, flat], DType::F32),
    )
}
/// Emits `activation(x * W + b)`, registering `W` as `{prefix}.w`
/// (`[in_dim, out_dim]`) and `b` as `{prefix}.b` (`[out_dim]`), both `F32`.
fn linear_act(
    g: &mut Graph,
    prefix: &str,
    x: NodeId,
    in_dim: usize,
    out_dim: usize,
    activation: ActivationKind,
) -> NodeId {
    let weight_shape = Shape::new(&[in_dim, out_dim], DType::F32);
    let weight = g.param(format!("{prefix}.w"), weight_shape);

    let bias_shape = Shape::new(&[out_dim], DType::F32);
    let bias = g.param(format!("{prefix}.b"), bias_shape);

    let projected = g.mm(x, weight);
    let affine = g.add(projected, bias);
    crate::nn::apply_activation(g, activation, affine)
}
/// Sums all nodes in `list` left to right and scales the total by
/// `1 / sqrt(list.len())`.
///
/// # Panics
/// Panics when `list` is empty.
fn res_avg(g: &mut Graph, list: &[NodeId]) -> NodeId {
    let head = list[0];
    let tail = &list[1..];
    let total = tail.iter().fold(head, |sum, &term| g.add(sum, term));
    let scale = scalar_const(g, 1.0 / (list.len() as f32).sqrt());
    g.mul(total, scale)
}
/// Adds an `F32` constant node of the given shape whose payload is all zero
/// bytes (four bytes per element).
fn zero_f32(g: &mut Graph, shape: &[usize]) -> NodeId {
    let elem_count = shape.iter().product::<usize>();
    let byte_len = elem_count * std::mem::size_of::<f32>();
    let zeros = rlx_ir::op::Op::Constant {
        data: vec![0u8; byte_len],
    };
    g.add_node(zeros, Vec::new(), Shape::new(shape, DType::F32))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: a single-layer repflows graph builds without ghost atoms
    /// (`nall == nloc`) and the resulting node embedding keeps `n_dim` as its
    /// last dimension.
    #[test]
    fn repflows_builds() {
        let cfg = RepflowsConfig {
            rcut: 6.0,
            rcut_smth: 0.5,
            nsel: 16,
            a_rcut: 4.0,
            a_rcut_smth: 0.5,
            a_sel: 8,
            ntypes: 2,
            nlayers: 1,
            n_dim: 16,
            e_dim: 8,
            a_dim: 4,
            activation_function: "tanh".into(),
            smooth: true,
            ln_eps: 1e-5,
            axis_neuron: 4,
            update_angle: true,
        };

        let (nf, nloc) = (1, 2);
        let nall = nloc;
        let (nnei, a_sel) = (cfg.nsel, cfg.a_sel);

        let mut g = Graph::new("repflows");
        let node_ebd_ext = g.input(
            "node_ext",
            Shape::new(&[nf, nall, cfg.n_dim], DType::F32),
        );
        let env_mat_raw = g.input("em", Shape::new(&[nf, nloc, nnei, 4], DType::F32));
        let nlist = g.input("nl", Shape::new(&[nf, nloc, nnei], DType::I32));
        let nlist_mask = g.input("nm", Shape::new(&[nf, nloc, nnei], DType::F32));
        let sw = g.input("sw", Shape::new(&[nf, nloc, nnei], DType::F32));
        let angle_input = g.input(
            "ai",
            Shape::new(&[nf, nloc, a_sel, a_sel, 1], DType::F32),
        );
        let angle_mask = g.input("am", Shape::new(&[nf, nloc, a_sel, a_sel], DType::F32));
        let atype_loc = g.input("at", Shape::new(&[nf, nloc], DType::I32));

        let inputs = RepflowsInputs {
            node_ebd_ext,
            env_mat_raw,
            nlist,
            nlist_mask,
            sw,
            angle_input,
            angle_mask,
            a_sw: None,
            atype_loc,
        };
        let out = build_repflows(&mut g, &cfg, inputs, nf, nloc, nall).expect("build");

        let expected_width = rlx_ir::Dim::Static(cfg.n_dim);
        assert_eq!(g.shape(out.node_ebd).dim(2), expected_width);
    }
}