use anyhow::{bail, Result};
use rlx_ir::infer::GraphExt;
use rlx_ir::op::ReduceOp;
use rlx_ir::{DType, Graph, NodeId, Shape};
use serde::{Deserialize, Serialize};
use crate::nn::{embedding_mlp, scalar_const, ActivationKind, MlpSpec};
/// Configuration for the `se_t` descriptor builder, (de)serializable through serde.
///
/// Fields carrying `#[serde(default ...)]` may be omitted from the serialized form.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct SeTConfig {
/// NOTE(review): presumably the neighbor cutoff radius; it is not read by the
/// code visible here — confirm against the env-matrix producer.
pub rcut: f64,
/// NOTE(review): presumably the radius where smooth switching begins; it is not
/// read by the code visible here — confirm against the env-matrix producer.
pub rcut_smth: f64,
/// Per-type neighbor slot counts. Its length is reported by `ntypes()` and its
/// sum by `nnei()`; `sel_cumsum()` turns it into slice offsets along the
/// neighbor axis.
pub sel: Vec<usize>,
/// Widths handed to the embedding MLP spec as `neuron`; the last entry is what
/// `ng()` returns. Falls back to `default_neuron()` when omitted.
#[serde(default = "default_neuron")]
pub neuron: Vec<usize>,
/// Forwarded unchanged to the embedding MLP spec; `false` when omitted.
#[serde(default)]
pub resnet_dt: bool,
/// Activation name, parsed with `ActivationKind::parse` when the descriptor is
/// built. Falls back to `default_activation()` when omitted.
#[serde(default = "default_activation")]
pub activation_function: String,
/// Optional list of names, `None` when omitted. NOTE(review): presumably one
/// name per atom type; it is not read by the code visible here — confirm.
#[serde(default)]
pub type_map: Option<Vec<String>>,
}
/// Serde fallback for `SeTConfig::neuron`: three widths, 24, 48 and 96.
fn default_neuron() -> Vec<usize> {
    Vec::from([24, 48, 96])
}
/// Serde fallback for `SeTConfig::activation_function`: the name `"tanh"`.
fn default_activation() -> String {
    String::from("tanh")
}
impl SeTConfig {
    /// Number of types, i.e. the number of entries in `sel`.
    pub fn ntypes(&self) -> usize {
        self.sel.len()
    }

    /// Total number of neighbor slots: the sum of all entries in `sel`.
    pub fn nnei(&self) -> usize {
        self.sel.iter().copied().sum()
    }

    /// Last entry of `neuron`.
    ///
    /// # Panics
    ///
    /// Panics with `"se_t: empty neuron list"` when `neuron` is empty.
    pub fn ng(&self) -> usize {
        self.neuron
            .last()
            .copied()
            .expect("se_t: empty neuron list")
    }

    /// Output width of the descriptor; identical to [`Self::ng`] (and panics
    /// under the same condition).
    pub fn dim_out(&self) -> usize {
        self.ng()
    }

    /// Prefix sums of `sel` with a leading zero.
    ///
    /// The result has `sel.len() + 1` entries; type `t` owns the half-open
    /// range `result[t]..result[t + 1]`.
    pub fn sel_cumsum(&self) -> Vec<usize> {
        let running_totals = self.sel.iter().scan(0, |total, &count| {
            *total += count;
            Some(*total)
        });
        std::iter::once(0).chain(running_totals).collect()
    }
}
/// Value returned by `build_se_t_descriptor`.
pub struct SeTDescriptor {
/// Graph node holding the descriptor accumulated over all type pairs.
pub descriptor: NodeId,
/// Set by the builder to the last entry of the config's `neuron` list.
pub dim_out: usize,
}
/// Adds the `se_t` descriptor sub-graph to `g` and returns its output node.
///
/// Steps, as built below:
/// 1. Normalize the env matrix with the `descriptor.davg` / `descriptor.dstd`
///    parameters (each `[ntypes, nnei, 4]`), gathered along axis 0 by
///    `atype_loc`.
/// 2. Optionally multiply by `exclude_mask`, reshaped to `[nf, nloc, nnei, 1]`.
/// 3. Keep components `1..4` of the last axis.
/// 4. For every type pair `ti <= tj` with non-empty selections: take the matrix
///    product of the two neighbor slices, feed it (with a trailing unit axis)
///    through the embedding MLP `descriptor.embedding.{ti}.{tj}`, weight the MLP
///    output by that product, sum over both neighbor axes and scale by
///    `1 / (nt_i * nt_j)`.
/// 5. Sum the per-pair results into a single `[nf, nloc, ng]` node.
///
/// # Parameters
///
/// * `g` - graph that receives the new parameters and nodes.
/// * `cfg` - descriptor configuration (`sel`, `neuron`, activation, ...).
/// * `env_mat_raw` - rank-4 env-matrix node, expected as `[nf, nloc, nnei, 4]`.
/// * `atype_loc` - index node used to gather the per-type statistics.
/// * `nf`, `nloc` - leading dimensions used for the explicit reshapes; they are
///   trusted, not cross-checked against the shape of `env_mat_raw`.
/// * `exclude_mask` - optional node reshaped to `[nf, nloc, nnei, 1]` and
///   multiplied into the normalized env matrix.
///
/// # Errors
///
/// Returns an error when the activation name cannot be parsed, when
/// `cfg.neuron` is empty, when `env_mat_raw` is not rank 4, or when every entry
/// of `cfg.sel` is zero (including an empty `sel`). All of these are detected
/// before anything is added to `g`.
pub fn build_se_t_descriptor(
    g: &mut Graph,
    cfg: &SeTConfig,
    env_mat_raw: NodeId,
    atype_loc: NodeId,
    nf: usize,
    nloc: usize,
    exclude_mask: Option<NodeId>,
) -> Result<SeTDescriptor> {
    let activation = ActivationKind::parse(&cfg.activation_function)?;
    // Read the output width directly: `cfg.ng()` would panic on an empty list,
    // and this function promises a `Result`.
    let ng = *cfg
        .neuron
        .last()
        .ok_or_else(|| anyhow::anyhow!("se_t: empty neuron list"))?;
    let ntypes = cfg.ntypes();
    let nnei = cfg.nnei();
    let sec = cfg.sel_cumsum();

    // Only the rank is needed, so there is no reason to clone the whole shape.
    let rank = g.shape(env_mat_raw).rank();
    if rank != 4 {
        bail!(
            "se_t: env-matrix input must have rank 4 [nf, nloc, nnei, 4], got rank {}",
            rank
        );
    }
    // Every type pairs with itself, so no pair survives exactly when all
    // selections are zero. Reject that before touching the graph.
    if nnei == 0 {
        bail!("se_t: empty selection");
    }

    let davg = g.param(
        "descriptor.davg",
        Shape::new(&[ntypes, nnei, 4], DType::F32),
    );
    let dstd = g.param(
        "descriptor.dstd",
        Shape::new(&[ntypes, nnei, 4], DType::F32),
    );
    // Pick the statistics row that belongs to each atom's type.
    let davg_g = g.gather_(davg, atype_loc, 0);
    let dstd_g = g.gather_(dstd, atype_loc, 0);
    let mut rr = g.sub(env_mat_raw, davg_g);
    rr = g.div(rr, dstd_g);

    if let Some(mask) = exclude_mask {
        // Trailing unit axis so the mask applies to all 4 env-matrix components.
        let mask_shape = Shape::new(&[nf, nloc, nnei, 1], DType::F32);
        let mask_4d = g.reshape(
            mask,
            vec![nf as i64, nloc as i64, nnei as i64, 1],
            mask_shape,
        );
        rr = g.mul(rr, mask_4d);
    }

    // Drop component 0 of the last axis and keep the remaining three.
    let rr_xyz = g.narrow_(rr, 3, 1, 3);

    let mut acc: Option<NodeId> = None;
    for ti in 0..ntypes {
        let nt_i = sec[ti + 1] - sec[ti];
        if nt_i == 0 {
            continue;
        }
        // Only the upper triangle (`tj >= ti`) of type pairs is visited.
        for tj in ti..ntypes {
            let nt_j = sec[tj + 1] - sec[tj];
            if nt_j == 0 {
                continue;
            }

            // [nf, nloc, nt_i, 3] x [nf, nloc, 3, nt_j] -> [nf, nloc, nt_i, nt_j]
            let rr_i = g.narrow_(rr_xyz, 2, sec[ti], nt_i);
            let rr_j = g.narrow_(rr_xyz, 2, sec[tj], nt_j);
            let rr_j_t = g.transpose_(rr_j, vec![0, 1, 3, 2]);
            let env_ij = g.mm(rr_i, rr_j_t);

            // Trailing unit axis: the embedding MLP takes a single input feature.
            let env_5d_shape = Shape::new(&[nf, nloc, nt_i, nt_j, 1], DType::F32);
            let env_5d = g.reshape(
                env_ij,
                vec![nf as i64, nloc as i64, nt_i as i64, nt_j as i64, 1],
                env_5d_shape,
            );

            let prefix = format!("descriptor.embedding.{ti}.{tj}");
            let mlp = MlpSpec {
                param_prefix: &prefix,
                in_dim: 1,
                neuron: &cfg.neuron,
                activation,
                resnet_dt: cfg.resnet_dt,
            };
            let gg = embedding_mlp(g, &mlp, env_5d);

            // Weight the embedding by the product it was computed from, then
            // collapse both neighbor axes.
            let weighted = g.mul(env_5d, gg);
            let sum_shape = Shape::new(&[nf, nloc, ng], DType::F32);
            let summed = g.reduce(weighted, ReduceOp::Sum, vec![2, 3], false, sum_shape);

            // Normalize by the number of (i, j) neighbor combinations.
            let scale = scalar_const(g, 1.0 / (nt_i as f32 * nt_j as f32));
            let scaled = g.mul(summed, scale);

            acc = Some(match acc {
                None => scaled,
                Some(prev) => g.add(prev, scaled),
            });
        }
    }

    // Unreachable after the `nnei == 0` guard; kept so the invariant is checked
    // rather than assumed.
    let res = acc.ok_or_else(|| anyhow::anyhow!("se_t: empty selection"))?;
    Ok(SeTDescriptor {
        descriptor: res,
        dim_out: ng,
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: a two-type configuration builds without error, reports the
    /// configured output width, and adds more than ten nodes to the graph.
    #[test]
    fn se_t_descriptor_builds() {
        let (nf, nloc) = (1, 4);
        let cfg = SeTConfig {
            rcut: 6.0,
            rcut_smth: 0.5,
            sel: vec![20, 40],
            neuron: vec![8, 16, 32],
            resnet_dt: false,
            activation_function: String::from("tanh"),
            type_map: None,
        };

        let mut graph = Graph::new("se_t_test");
        let env_shape = Shape::new(&[nf, nloc, cfg.nnei(), 4], DType::F32);
        let env_mat = graph.input("env_mat", env_shape);
        let atype_shape = Shape::new(&[nf, nloc], DType::I32);
        let atype = graph.input("atype", atype_shape);

        let built = build_se_t_descriptor(&mut graph, &cfg, env_mat, atype, nf, nloc, None)
            .expect("build");

        assert_eq!(built.dim_out, cfg.dim_out());
        assert!(graph.len() > 10);
    }
}