use anyhow::{bail, Result};
use rlx_ir::infer::GraphExt;
use rlx_ir::op::ReduceOp;
use rlx_ir::{DType, Graph, NodeId, Shape};
use serde::{Deserialize, Serialize};
use crate::nn::{embedding_mlp, scalar_const, ActivationKind, MlpSpec};
/// Configuration for the `se_r` descriptor sub-graph assembled by
/// `build_se_r_descriptor`.
///
/// The struct round-trips through serde; fields carrying a `#[serde(default…)]`
/// attribute may be omitted from the serialized form.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct SeRConfig {
/// Cutoff radius. Not read by any code in this file — presumably consumed
/// where the environment matrix is produced (TODO confirm usage and units).
pub rcut: f64,
/// Smoothing radius paired with `rcut`. Also not read in this file
/// (TODO confirm usage and units).
pub rcut_smth: f64,
/// Neighbour count per type: the vector length is the number of types and
/// the sum of its entries is the total neighbour count.
pub sel: Vec<usize>,
/// Layer widths handed to the embedding MLP; the last entry is the
/// descriptor output width. Defaults to `[24, 48, 96]` when omitted.
#[serde(default = "default_embedding_neuron")]
pub neuron: Vec<usize>,
/// Forwarded unchanged to the embedding MLP spec. Defaults to `false`.
#[serde(default)]
pub resnet_dt: bool,
/// Activation name, parsed with `ActivationKind::parse` when the graph is
/// built. Defaults to `"tanh"` when omitted.
#[serde(default = "default_activation")]
pub activation_function: String,
/// Optional list of type names. Not read by any code in this file
/// (NOTE(review): presumably maps type indices to labels — confirm).
#[serde(default)]
pub type_map: Option<Vec<String>>,
}
/// Serde fallback for `SeRConfig::neuron`: three layers of width 24, 48 and 96.
fn default_embedding_neuron() -> Vec<usize> {
    const DEFAULT_WIDTHS: [usize; 3] = [24, 48, 96];
    DEFAULT_WIDTHS.to_vec()
}
/// Serde fallback for `SeRConfig::activation_function`.
fn default_activation() -> String {
    String::from("tanh")
}
impl SeRConfig {
    /// Number of types, i.e. the number of entries in `sel`.
    pub fn ntypes(&self) -> usize {
        let Self { sel, .. } = self;
        sel.len()
    }

    /// Total neighbour count: the sum of every entry in `sel`.
    pub fn nnei(&self) -> usize {
        self.sel.iter().copied().sum()
    }

    /// Width of the last embedding layer.
    ///
    /// # Panics
    ///
    /// Panics if `neuron` is empty.
    pub fn ng(&self) -> usize {
        self.neuron
            .last()
            .copied()
            .expect("se_r: empty neuron list")
    }

    /// Output width of the descriptor; identical to [`Self::ng`].
    pub fn dim_out(&self) -> usize {
        Self::ng(self)
    }

    /// Running totals of `sel`, starting with a leading `0`.
    ///
    /// The result has `sel.len() + 1` entries; entries `t` and `t + 1` bound
    /// the neighbour slots that belong to type `t`.
    pub fn sel_cumsum(&self) -> Vec<usize> {
        let mut running = 0usize;
        let partial_sums = self.sel.iter().map(|&count| {
            running += count;
            running
        });
        std::iter::once(0).chain(partial_sums).collect()
    }
}
/// Value returned by `build_se_r_descriptor`.
pub struct SeRDescriptor {
/// Graph node that holds the final descriptor output.
pub descriptor: NodeId,
/// Size of the descriptor's last axis; the builder sets it to the last
/// entry of `SeRConfig::neuron`.
pub dim_out: usize,
}
/// Appends the `se_r` descriptor sub-graph to `g` and returns its output node.
///
/// The construction proceeds as follows:
///
/// 1. Two parameters, `descriptor.davg` and `descriptor.dstd`, both of shape
///    `[ntypes, nnei, 1]`, are gathered along axis 0 with `atype_loc` and used
///    to normalise the input: `(env_mat_raw - davg) / dstd`.
/// 2. If `exclude_mask` is given it is reshaped to `[nf, nloc, nnei, 1]` and
///    multiplied into the normalised input.
/// 3. For every type `t` with a non-zero `sel[t]`, the matching slice of
///    axis 2 is fed through an embedding MLP whose parameters are prefixed
///    `descriptor.embedding.{t}`; the MLP output is summed over axis 2 and
///    scaled by `1 / nnei`. The per-type results are added together.
/// 4. The accumulated value is multiplied by the constant `0.2`.
///
/// # Arguments
///
/// * `g` - graph that receives the new nodes and parameters.
/// * `cfg` - descriptor configuration (`sel`, `neuron`, activation, …).
/// * `env_mat_raw` - input node; must have rank 4. The layout expected by the
///   rest of the function is `[nf, nloc, nnei, 1]`, but only the rank is
///   checked here.
/// * `atype_loc` - index node used to gather per-type statistics
///   (presumably integer type ids shaped `[nf, nloc]` — TODO confirm).
/// * `nf`, `nloc` - leading dimensions used for the mask reshape and for the
///   `[nf, nloc, ng]` shape of each per-type reduction.
/// * `exclude_mask` - optional node reshaped to `[nf, nloc, nnei, 1]` and
///   multiplied element-wise into the normalised input.
///
/// # Errors
///
/// Returns an error when
/// * `cfg.activation_function` is rejected by `ActivationKind::parse`,
/// * `cfg.neuron` is empty,
/// * `env_mat_raw` does not have rank 4, or
/// * every entry of `cfg.sel` is zero (including an empty `sel`).
pub fn build_se_r_descriptor(
    g: &mut Graph,
    cfg: &SeRConfig,
    env_mat_raw: NodeId,
    atype_loc: NodeId,
    nf: usize,
    nloc: usize,
    exclude_mask: Option<NodeId>,
) -> Result<SeRDescriptor> {
    let activation = ActivationKind::parse(&cfg.activation_function)?;
    let ntypes = cfg.ntypes();
    let nnei = cfg.nnei();

    // Read the output width straight from the config so that an empty layer
    // list surfaces as an error instead of the panic inside `cfg.ng()`.
    let ng = match cfg.neuron.last() {
        Some(&width) => width,
        None => bail!("se_r: empty neuron list"),
    };

    // Only the rank is needed, so there is no reason to clone the shape.
    let input_rank = g.shape(env_mat_raw).rank();
    if input_rank != 4 {
        bail!(
            "se_r: env-matrix input must have rank 4 [nf, nloc, nnei, 1], got rank {}",
            input_rank
        );
    }

    // Fail before touching the graph: with no neighbours selected there is
    // nothing to embed, and `1 / nnei` below would not be finite.
    if nnei == 0 {
        bail!("se_r: empty selection (sum(sel) == 0)");
    }

    let sec = cfg.sel_cumsum();

    // Per-type normalisation statistics, selected for each atom by its type.
    let davg = g.param(
        "descriptor.davg",
        Shape::new(&[ntypes, nnei, 1], DType::F32),
    );
    let dstd = g.param(
        "descriptor.dstd",
        Shape::new(&[ntypes, nnei, 1], DType::F32),
    );
    let davg_g = g.gather_(davg, atype_loc, 0);
    let dstd_g = g.gather_(dstd, atype_loc, 0);
    let mut rr = g.sub(env_mat_raw, davg_g);
    rr = g.div(rr, dstd_g);

    if let Some(mask) = exclude_mask {
        let mask_shape = Shape::new(&[nf, nloc, nnei, 1], DType::F32);
        let mask_4d = g.reshape(
            mask,
            vec![nf as i64, nloc as i64, nnei as i64, 1],
            mask_shape,
        );
        rr = g.mul(rr, mask_4d);
    }

    // The `1 / nnei` factor is the same for every type, so emit it once.
    let inv_nnei = scalar_const(g, 1.0 / nnei as f32);

    let mut acc: Option<NodeId> = None;
    // `sec[t]` is the first neighbour slot of type `t`; `sel[t]` is its length.
    for (t, (&start, &len)) in sec.iter().zip(&cfg.sel).enumerate() {
        if len == 0 {
            continue;
        }
        let rr_t = g.narrow_(rr, 2, start, len);
        let prefix = format!("descriptor.embedding.{t}");
        let mlp = MlpSpec {
            param_prefix: &prefix,
            in_dim: 1,
            neuron: &cfg.neuron,
            activation,
            resnet_dt: cfg.resnet_dt,
        };
        let gg_t = embedding_mlp(g, &mlp, rr_t);

        // Sum over this type's neighbour axis, then divide by the *total*
        // neighbour count so the accumulated value is a mean over all slots.
        let mean_shape = Shape::new(&[nf, nloc, ng], DType::F32);
        let sum_g = g.reduce(gg_t, ReduceOp::Sum, vec![2], false, mean_shape);
        let scaled = g.mul(sum_g, inv_nnei);

        acc = Some(match acc {
            None => scaled,
            Some(prev) => g.add(prev, scaled),
        });
    }

    // Unreachable after the `nnei == 0` guard, kept as a defensive fallback.
    let xyz_scatter =
        acc.ok_or_else(|| anyhow::anyhow!("se_r: empty selection (sum(sel) == 0)"))?;

    let five_inv = scalar_const(g, 0.2);
    let res = g.mul(xyz_scatter, five_inv);

    Ok(SeRDescriptor {
        descriptor: res,
        dim_out: ng,
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: the builder succeeds on a two-type configuration, reports
    /// the configured output width, and adds more than ten nodes to the graph.
    #[test]
    fn se_r_descriptor_builds() {
        let (nf, nloc) = (1, 8);
        let cfg = SeRConfig {
            type_map: None,
            activation_function: String::from("tanh"),
            resnet_dt: false,
            neuron: vec![16, 32, 64],
            sel: vec![46, 92],
            rcut_smth: 0.5,
            rcut: 6.0,
        };
        let total_neighbours = cfg.nnei();

        let mut graph = Graph::new("se_r_test");
        let env_shape = Shape::new(&[nf, nloc, total_neighbours, 1], DType::F32);
        let env_mat = graph.input("env_mat", env_shape);
        let atype_shape = Shape::new(&[nf, nloc], DType::I32);
        let atype = graph.input("atype", atype_shape);

        let built = build_se_r_descriptor(&mut graph, &cfg, env_mat, atype, nf, nloc, None)
            .expect("build");

        assert_eq!(built.dim_out, cfg.dim_out());
        assert!(graph.len() > 10);
    }
}