use anyhow::{bail, Result};
use rlx_ir::infer::GraphExt;
use rlx_ir::op::ReduceOp;
use rlx_ir::{DType, Graph, NodeId, Shape};
use serde::{Deserialize, Serialize};
use crate::nn::{embedding_mlp, scalar_const, ActivationKind, MlpSpec};
/// How type embeddings are fed into the `se_t_tebd` pair-embedding network.
///
/// Serialized in lowercase (`"concat"` / `"strip"`). Defaults to [`TebdInputMode::Concat`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum TebdInputMode {
    /// Concatenate both neighbors' type embeddings onto each pair's `env_ij` value and
    /// run a single embedding network on the result.
    #[default]
    Concat,
    /// Embed `env_ij` and the neighbor type pair with separate networks and combine
    /// them multiplicatively.
    Strip,
}
/// Hyper-parameters of the `se_t_tebd` descriptor, (de)serializable with serde.
///
/// Fields carrying a `#[serde(default ...)]` attribute may be omitted from the
/// serialized config and fall back to the listed default.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct SeTTebdConfig {
/// Cutoff radius. Not read by `build_se_t_tebd_descriptor`; presumably consumed where
/// the environment matrix is built (NOTE(review): confirm).
pub rcut: f64,
/// Smoothing radius (by name). Not read by the graph builder in this file
/// (NOTE(review): confirm where it is used).
pub rcut_smth: f64,
/// Per-type neighbor selection; the total neighbor count `nnei` is the sum of entries.
pub sel: Vec<usize>,
/// Number of atom types; the leading size of the `davg` / `dstd` parameters.
pub ntypes: usize,
/// Layer widths of the embedding network(s); the last entry is the output width `ng`.
/// Defaults to `[25, 50, 100]`.
#[serde(default = "default_neuron")]
pub neuron: Vec<usize>,
/// Width of one type-embedding row. Defaults to `8`.
#[serde(default = "default_tebd_dim")]
pub tebd_dim: usize,
/// How type embeddings enter the embedding network. Defaults to `Concat`.
#[serde(default)]
pub tebd_input_mode: TebdInputMode,
/// Forwarded to the embedding MLPs as `resnet_dt`. Defaults to `false`.
#[serde(default)]
pub resnet_dt: bool,
/// Activation name, parsed with `ActivationKind::parse`. Defaults to `"tanh"`.
#[serde(default = "default_activation")]
pub activation_function: String,
/// In `Strip` mode, scale the type-pair term by the switching weights `sw` when they
/// are supplied. Unused in `Concat` mode. Defaults to `true`.
#[serde(default = "default_true")]
pub smooth: bool,
/// Prefix for every parameter name this descriptor creates (e.g. `{prefix}.davg`).
/// Defaults to `"descriptor"`.
#[serde(default = "default_param_prefix")]
pub param_prefix: String,
}
/// Serde default for `SeTTebdConfig::param_prefix`.
fn default_param_prefix() -> String {
    String::from("descriptor")
}

/// Serde default for `SeTTebdConfig::neuron`: three embedding layers.
fn default_neuron() -> Vec<usize> {
    [25, 50, 100].to_vec()
}

/// Serde default for `SeTTebdConfig::tebd_dim`.
fn default_tebd_dim() -> usize {
    const DEFAULT_TEBD_DIM: usize = 8;
    DEFAULT_TEBD_DIM
}

/// Serde default for `SeTTebdConfig::activation_function`.
fn default_activation() -> String {
    "tanh".to_owned()
}

/// Serde default for boolean flags that are on unless disabled (e.g. `smooth`).
fn default_true() -> bool {
    const ENABLED: bool = true;
    ENABLED
}
impl SeTTebdConfig {
    /// Total neighbor count: the sum of the per-type selections in `sel`.
    pub fn nnei(&self) -> usize {
        self.sel.iter().copied().sum()
    }

    /// Output width of the embedding network, i.e. the last entry of `neuron`.
    ///
    /// # Panics
    ///
    /// Panics if `neuron` is empty.
    pub fn ng(&self) -> usize {
        match self.neuron.last() {
            Some(&width) => width,
            None => panic!("se_t_tebd: empty neuron list"),
        }
    }

    /// Width of the descriptor's last axis; identical to [`SeTTebdConfig::ng`].
    ///
    /// # Panics
    ///
    /// Panics if `neuron` is empty.
    pub fn dim_out(&self) -> usize {
        self.ng()
    }
}
/// Result of `build_se_t_tebd_descriptor`.
pub struct SeTTebdDescriptor {
/// Descriptor node; built with shape `[nf, nloc, dim_out]`.
pub descriptor: NodeId,
/// Width of the descriptor's last axis (the last entry of the config's `neuron`).
pub dim_out: usize,
}
/// Graph nodes consumed by `build_se_t_tebd_descriptor`.
pub struct SeTTebdInputs {
/// Un-normalized environment matrix; must be rank 4. The builder normalizes it with
/// per-type `davg` / `dstd` and reads components `1..4` of the last axis, so it is
/// expected to be `[nf, nloc, nnei, 4]`.
pub env_mat_raw: NodeId,
/// Type index of each local (center) atom; selects rows of `davg` / `dstd`
/// (the tests pass `[nf, nloc]` `I32`).
pub atype_loc: NodeId,
/// Type index of each neighbor; gathers type-embedding rows and, in `Strip` mode, is
/// reshaped to `[nf, nloc, nnei, 1]` / `[nf, nloc, 1, nnei]` `I32` views.
pub nei_atype: NodeId,
/// Type-embedding table indexed by type along axis 0; rows are `tebd_dim` wide.
pub type_embedding: NodeId,
/// Optional per-neighbor switching weights, reshaped to `[nf, nloc, nnei, 1, 1]`;
/// only read in `Strip` mode when `smooth` is enabled.
pub sw: Option<NodeId>,
/// Optional per-neighbor mask, reshaped to `[nf, nloc, nnei, 1]` and multiplied into
/// the normalized environment matrix.
pub exclude_mask: Option<NodeId>,
}
pub fn build_se_t_tebd_descriptor(
g: &mut Graph,
cfg: &SeTTebdConfig,
inputs: SeTTebdInputs,
nf: usize,
nloc: usize,
) -> Result<SeTTebdDescriptor> {
let activation = ActivationKind::parse(&cfg.activation_function)?;
let nnei = cfg.nnei();
let ng = cfg.ng();
let ntypes = cfg.ntypes;
let tebd = cfg.tebd_dim;
let rr_shape = g.shape(inputs.env_mat_raw).clone();
if rr_shape.rank() != 4 {
bail!("se_t_tebd: env-matrix input must have rank 4");
}
let davg_name = format!("{}.davg", cfg.param_prefix);
let dstd_name = format!("{}.dstd", cfg.param_prefix);
let davg = g.param(
&davg_name,
Shape::new(&[ntypes, nnei, 4], DType::F32),
);
let dstd = g.param(
&dstd_name,
Shape::new(&[ntypes, nnei, 4], DType::F32),
);
let davg_g = g.gather_(davg, inputs.atype_loc, 0);
let dstd_g = g.gather_(dstd, inputs.atype_loc, 0);
let mut rr = g.sub(inputs.env_mat_raw, davg_g);
rr = g.div(rr, dstd_g);
if let Some(mask) = inputs.exclude_mask {
let mask_shape = Shape::new(&[nf, nloc, nnei, 1], DType::F32);
let mask_4d = g.reshape(
mask,
vec![nf as i64, nloc as i64, nnei as i64, 1],
mask_shape,
);
rr = g.mul(rr, mask_4d);
}
let rr_xyz = g.narrow_(rr, 3, 1, 3); let rr_t = g.transpose_(rr_xyz, vec![0, 1, 3, 2]);
let env_ij = g.mm(rr_xyz, rr_t);
let env_5d_shape = Shape::new(&[nf, nloc, nnei, nnei, 1], DType::F32);
let env_5d = g.reshape(
env_ij,
vec![nf as i64, nloc as i64, nnei as i64, nnei as i64, 1],
env_5d_shape,
);
let t_nei = g.gather_(inputs.type_embedding, inputs.nei_atype, 0);
let gg = match cfg.tebd_input_mode {
TebdInputMode::Concat => {
let t_i_shape = Shape::new(&[nf, nloc, nnei, 1, tebd], DType::F32);
let t_i = g.reshape(
t_nei,
vec![nf as i64, nloc as i64, nnei as i64, 1, tebd as i64],
t_i_shape,
);
let t_j_shape = Shape::new(&[nf, nloc, 1, nnei, tebd], DType::F32);
let t_j = g.reshape(
t_nei,
vec![nf as i64, nloc as i64, 1, nnei as i64, tebd as i64],
t_j_shape,
);
let zero_i_shape = Shape::new(&[nf, nloc, nnei, nnei, tebd], DType::F32);
let zero_i = zero_tensor(g, &zero_i_shape);
let t_i_b = g.add(t_i, zero_i);
let zero_j_shape = Shape::new(&[nf, nloc, nnei, nnei, tebd], DType::F32);
let zero_j = zero_tensor(g, &zero_j_shape);
let t_j_b = g.add(t_j, zero_j);
let concat_shape =
Shape::new(&[nf, nloc, nnei, nnei, 1 + 2 * tebd], DType::F32);
let ss_concat = g.concat(vec![env_5d, t_i_b, t_j_b], 4, concat_shape);
let embedding_prefix = format!("{}.embedding", cfg.param_prefix);
let mlp = MlpSpec {
param_prefix: &embedding_prefix,
in_dim: 1 + 2 * tebd,
neuron: &cfg.neuron,
activation,
resnet_dt: cfg.resnet_dt,
};
embedding_mlp(g, &mlp, ss_concat) }
TebdInputMode::Strip => {
let embedding_prefix = format!("{}.embedding", cfg.param_prefix);
let mlp_main = MlpSpec {
param_prefix: &embedding_prefix,
in_dim: 1,
neuron: &cfg.neuron,
activation,
resnet_dt: cfg.resnet_dt,
};
let gg_s = embedding_mlp(g, &mlp_main, env_5d);
let rows = ntypes_with_pad_from_shape(g, inputs.type_embedding);
let pair_in_shape = Shape::new(&[rows * rows, 2 * tebd], DType::F32);
let tt_pair = build_pair_table(g, inputs.type_embedding, rows, tebd)?;
let strip_prefix = format!("{}.embedding_strip", cfg.param_prefix);
let mlp_strip = MlpSpec {
param_prefix: &strip_prefix,
in_dim: 2 * tebd,
neuron: &cfg.neuron,
activation,
resnet_dt: cfg.resnet_dt,
};
let _ = pair_in_shape;
let tt_full = embedding_mlp(g, &mlp_strip, tt_pair);
let rows_const = scalar_i32_const(g, rows as i32);
let ti_shape = Shape::new(&[nf, nloc, nnei, 1], DType::I32);
let nei_i_4d = g.reshape(
inputs.nei_atype,
vec![nf as i64, nloc as i64, nnei as i64, 1],
ti_shape,
);
let tj_shape = Shape::new(&[nf, nloc, 1, nnei], DType::I32);
let nei_j_4d = g.reshape(
inputs.nei_atype,
vec![nf as i64, nloc as i64, 1, nnei as i64],
tj_shape,
);
let i_scaled = g.binary(
rlx_ir::op::BinaryOp::Mul,
nei_i_4d,
rows_const,
g.shape(nei_i_4d).clone(),
);
let pair_idx_shape = Shape::new(&[nf, nloc, nnei, nnei], DType::I32);
let pair_idx_4d =
g.binary(rlx_ir::op::BinaryOp::Add, i_scaled, nei_j_4d, pair_idx_shape);
let flat_total = nf * nloc * nnei * nnei;
let flat_shape = Shape::new(&[flat_total], DType::I32);
let pair_idx_flat = g.reshape(pair_idx_4d, vec![flat_total as i64], flat_shape);
let gg_t_flat = g.gather_(tt_full, pair_idx_flat, 0); let gg_t_shape = Shape::new(&[nf, nloc, nnei, nnei, ng], DType::F32);
let gg_t = g.reshape(
gg_t_flat,
vec![
nf as i64,
nloc as i64,
nnei as i64,
nnei as i64,
ng as i64,
],
gg_t_shape,
);
let mut gg_t = gg_t;
if cfg.smooth {
if let Some(sw) = inputs.sw {
let sw_i_shape = Shape::new(&[nf, nloc, nnei, 1, 1], DType::F32);
let sw_i = g.reshape(
sw,
vec![nf as i64, nloc as i64, nnei as i64, 1, 1],
sw_i_shape,
);
let sw_j_shape = Shape::new(&[nf, nloc, 1, nnei, 1], DType::F32);
let sw_j = g.reshape(
sw,
vec![nf as i64, nloc as i64, 1, nnei as i64, 1],
sw_j_shape,
);
gg_t = g.mul(gg_t, sw_i);
gg_t = g.mul(gg_t, sw_j);
}
}
let prod = g.mul(gg_s, gg_t);
g.add(prod, gg_s)
}
};
let weighted = g.mul(env_5d, gg);
let sum_shape = Shape::new(&[nf, nloc, ng], DType::F32);
let summed = g.reduce(weighted, ReduceOp::Sum, vec![2, 3], false, sum_shape);
let inv_nnei2 = scalar_const(g, 1.0 / (nnei as f32 * nnei as f32));
let res = g.mul(summed, inv_nnei2);
Ok(SeTTebdDescriptor {
descriptor: res,
dim_out: ng,
})
}
/// Adds an all-zero constant node with the given `shape` to `g`.
///
/// The buffer is sized at 4 bytes (one `f32`) per element, so this is only meant for
/// 4-byte dtypes; callers here pass `DType::F32` shapes.
///
/// # Panics
///
/// Panics if any dimension of `shape` is not static. A constant needs a concrete element
/// count, and treating an unknown dimension as 0 would yield a data buffer whose length
/// disagrees with the node's declared shape.
fn zero_tensor(g: &mut Graph, shape: &Shape) -> NodeId {
    let n: usize = shape
        .dims()
        .iter()
        .map(|d| match d {
            rlx_ir::Dim::Static(n) => *n,
            _ => panic!("zero_tensor: constant shape must be fully static"),
        })
        .product();
    let bytes = vec![0u8; n * std::mem::size_of::<f32>()];
    g.add_node(
        rlx_ir::op::Op::Constant { data: bytes },
        vec![],
        shape.clone(),
    )
}
/// Adds a one-element `I32` constant node holding `v` (stored little-endian).
fn scalar_i32_const(g: &mut Graph, v: i32) -> NodeId {
    let shape = Shape::new(&[1], DType::I32);
    let data = Vec::from(v.to_le_bytes());
    g.add_node(rlx_ir::op::Op::Constant { data }, vec![], shape)
}
/// Returns the static number of rows (axis 0) of the type-embedding `table`.
///
/// # Panics
///
/// Panics if axis 0 of `table` is not a static dimension.
fn ntypes_with_pad_from_shape(g: &Graph, table: NodeId) -> usize {
    let row_dim = g.shape(table).dim(0);
    if let rlx_ir::Dim::Static(rows) = row_dim {
        rows
    } else {
        panic!("type embedding table must have static row count")
    }
}
/// Builds the `[rows * rows, 2 * tebd]` table of every ordered type pair.
///
/// Row `a * rows + b` holds `concat(table[a], table[b])`. This never fails; the
/// `Result` return type is kept for call-site uniformity.
fn build_pair_table(
    g: &mut Graph,
    table: NodeId,
    rows: usize,
    tebd: usize,
) -> Result<NodeId> {
    /// Serializes an index list as little-endian `i32` bytes.
    fn le_bytes(indices: &[i32]) -> Vec<u8> {
        indices.iter().flat_map(|v| v.to_le_bytes()).collect()
    }

    let pair_count = rows * rows;
    // Row-major enumeration: the first index repeats `rows` times while the second cycles.
    let (first, second): (Vec<i32>, Vec<i32>) = (0..rows)
        .flat_map(|a| (0..rows).map(move |b| (a as i32, b as i32)))
        .unzip();

    let first_idx = g.add_node(
        rlx_ir::op::Op::Constant {
            data: le_bytes(&first),
        },
        vec![],
        Shape::new(&[pair_count], DType::I32),
    );
    let second_idx = g.add_node(
        rlx_ir::op::Op::Constant {
            data: le_bytes(&second),
        },
        vec![],
        Shape::new(&[pair_count], DType::I32),
    );
    let left = g.gather_(table, first_idx, 0);
    let right = g.gather_(table, second_idx, 0);
    let pairs = g.concat(
        vec![left, right],
        1,
        Shape::new(&[pair_count, 2 * tebd], DType::F32),
    );
    Ok(pairs)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Test config; everything not passed in is shared by both tests.
    fn test_config(
        sel: Vec<usize>,
        neuron: Vec<usize>,
        tebd_dim: usize,
        tebd_input_mode: TebdInputMode,
        smooth: bool,
    ) -> SeTTebdConfig {
        SeTTebdConfig {
            rcut: 6.0,
            rcut_smth: 0.5,
            sel,
            ntypes: 2,
            neuron,
            tebd_dim,
            tebd_input_mode,
            resnet_dt: false,
            activation_function: "tanh".into(),
            smooth,
            param_prefix: "descriptor".into(),
        }
    }

    /// Declares the descriptor's graph inputs (with a type-embedding table of
    /// `ntypes + 1` rows) and builds the descriptor, panicking on failure.
    fn build_with_inputs(
        g: &mut Graph,
        cfg: &SeTTebdConfig,
        nf: usize,
        nloc: usize,
        with_sw: bool,
    ) -> SeTTebdDescriptor {
        let nnei = cfg.nnei();
        let env_mat_raw = g.input("env_mat", Shape::new(&[nf, nloc, nnei, 4], DType::F32));
        let atype_loc = g.input("atype", Shape::new(&[nf, nloc], DType::I32));
        let nei_atype = g.input("nei_atype", Shape::new(&[nf, nloc, nnei], DType::I32));
        let sw = if with_sw {
            Some(g.input("sw", Shape::new(&[nf, nloc, nnei], DType::F32)))
        } else {
            None
        };
        let type_embedding = g.param(
            "type_embed.table",
            Shape::new(&[cfg.ntypes + 1, cfg.tebd_dim], DType::F32),
        );
        let inputs = SeTTebdInputs {
            env_mat_raw,
            atype_loc,
            nei_atype,
            type_embedding,
            sw,
            exclude_mask: None,
        };
        build_se_t_tebd_descriptor(g, cfg, inputs, nf, nloc).expect("build")
    }

    #[test]
    fn se_t_tebd_concat_builds() {
        let cfg = test_config(vec![20, 20], vec![16, 32], 8, TebdInputMode::Concat, false);
        let mut g = Graph::new("se_t_tebd_concat");
        let out = build_with_inputs(&mut g, &cfg, 1, 4, false);
        assert_eq!(out.dim_out, cfg.dim_out());
        assert!(g.len() > 20);
    }

    #[test]
    fn se_t_tebd_strip_builds() {
        let cfg = test_config(vec![10, 10], vec![8, 16], 4, TebdInputMode::Strip, true);
        let mut g = Graph::new("se_t_tebd_strip");
        let out = build_with_inputs(&mut g, &cfg, 1, 2, true);
        assert_eq!(out.dim_out, cfg.dim_out());
        assert!(g.len() > 30);
    }
}