use anyhow::{bail, Result};
use rlx_ir::infer::GraphExt;
use rlx_ir::{DType, Graph, NodeId, Shape};
use crate::config::EnerFittingConfig;
use crate::nn::{
embedding_mlp, scalar_const, ActivationKind, DenseLayerSpec, MlpSpec,
};
/// Graph nodes fed into a fitting network.
///
/// Which optional fields must be present depends on the configuration:
/// `fparam`, `aparam` and `case_embd` are required by `build_fitting_trunk`
/// when `numb_fparam`, `numb_aparam` and `dim_case_embd` are non-zero, and
/// `gr` is required by the dipole and polar builders.
pub struct FittingInputs {
    /// Per-atom descriptor; concatenated along its last axis with the enabled
    /// optional inputs (the tests build it as `[nf, nloc, dim_descrpt]`).
    pub descriptor: NodeId,
    /// Types of the local atoms; used for per-type masks and per-type gathers
    /// (the tests build it as an I32 `[nf, nloc]` input).
    pub atype_loc: NodeId,
    /// Frame parameters, normalized with `{prefix}.fparam_avg` / `{prefix}.fparam_inv_std`.
    pub fparam: Option<NodeId>,
    /// Atomic parameters, normalized with `{prefix}.aparam_avg` / `{prefix}.aparam_inv_std`.
    pub aparam: Option<NodeId>,
    /// Case embedding appended unchanged to the MLP input.
    pub case_embd: Option<NodeId>,
    /// Optional mask multiplied into the final output; `None` disables masking.
    pub exclude_mask: Option<NodeId>,
    /// Equivariant representation used by dipole/polar fitting
    /// (the tests build it as `[nf, nloc, embedding_width, 3]`).
    pub gr: Option<NodeId>,
}
/// Builds the fitting MLP shared by the invariant, dipole and polar heads.
///
/// The MLP input is `inputs.descriptor` extended along its last axis with the
/// following, in this order and only when enabled in `cfg`:
/// 1. frame parameters, normalized as `(fparam - fparam_avg) * fparam_inv_std`;
/// 2. atomic parameters, normalized the same way with the `aparam_*` params;
/// 3. the case embedding, unchanged.
///
/// With `cfg.mixed_types`, a single MLP (params under `param_prefix`) is applied
/// to every atom. Otherwise one MLP per type (`{param_prefix}.type_{i}`) is run
/// over all atoms, and each result is masked to the atoms of that type before
/// the results are summed.
///
/// The final layer has `net_out_dim` outputs; `cfg.dim_out` is not read here.
///
/// # Errors
/// Returns an error if:
/// - `cfg.activation_function` fails to parse;
/// - an optional input enabled by `cfg` is missing from `inputs`;
/// - `cfg.mixed_types` is false and `cfg.ntypes == 0`.
fn build_fitting_trunk(
    g: &mut Graph,
    cfg: &EnerFittingConfig,
    inputs: &FittingInputs,
    nf: usize,
    nloc: usize,
    net_out_dim: usize,
    param_prefix: &str,
) -> Result<NodeId> {
    let activation = ActivationKind::parse(&cfg.activation_function)?;
    let mut x = inputs.descriptor;
    let mut in_dim = cfg.dim_descrpt;

    // NOTE(review): fparam is concatenated with the `[nf, nloc, ·]` descriptor
    // as-is, so it is assumed to already be broadcast to
    // `[nf, nloc, numb_fparam]` — TODO confirm with the caller.
    if cfg.numb_fparam > 0 {
        let fparam = inputs
            .fparam
            .ok_or_else(|| anyhow::anyhow!("numb_fparam > 0 but no fparam node provided"))?;
        let normed = normalize_extra_input(g, fparam, cfg.numb_fparam, param_prefix, "fparam");
        x = concat_last(g, x, normed, nf, nloc, in_dim, cfg.numb_fparam);
        in_dim += cfg.numb_fparam;
    }
    if cfg.numb_aparam > 0 {
        let aparam = inputs
            .aparam
            .ok_or_else(|| anyhow::anyhow!("numb_aparam > 0 but no aparam node provided"))?;
        let normed = normalize_extra_input(g, aparam, cfg.numb_aparam, param_prefix, "aparam");
        x = concat_last(g, x, normed, nf, nloc, in_dim, cfg.numb_aparam);
        in_dim += cfg.numb_aparam;
    }
    if cfg.dim_case_embd > 0 {
        let case_embd = inputs.case_embd.ok_or_else(|| {
            anyhow::anyhow!("dim_case_embd > 0 but no case_embd node provided")
        })?;
        x = concat_last(g, x, case_embd, nf, nloc, in_dim, cfg.dim_case_embd);
        in_dim += cfg.dim_case_embd;
    }

    let out = if cfg.mixed_types {
        fitting_mlp_path(
            g,
            param_prefix,
            in_dim,
            &cfg.neuron,
            net_out_dim,
            activation,
            cfg.resnet_dt,
            x,
        )
    } else {
        let mut acc: Option<NodeId> = None;
        for ti in 0..cfg.ntypes {
            let prefix = format!("{param_prefix}.type_{ti}");
            let y_ti = fitting_mlp_path(
                g,
                &prefix,
                in_dim,
                &cfg.neuron,
                net_out_dim,
                activation,
                cfg.resnet_dt,
                x,
            );
            // Zero out atoms that are not of type `ti`, so that summing the
            // per-type outputs selects exactly one network per atom.
            let mask = atype_mask(g, inputs.atype_loc, ti as i32, nf, nloc);
            let mask_3d = g.reshape(
                mask,
                vec![nf as i64, nloc as i64, 1],
                Shape::new(&[nf, nloc, 1], DType::F32),
            );
            let masked = g.mul(y_ti, mask_3d);
            acc = Some(match acc {
                None => masked,
                Some(prev) => g.add(prev, masked),
            });
        }
        acc.ok_or_else(|| anyhow::anyhow!("fitting: ntypes == 0"))?
    };
    Ok(out)
}

/// Normalizes an extra fitting input as `(value - avg) * inv_std`.
///
/// `avg` and `inv_std` are registered as `[width]` F32 parameters named
/// `{param_prefix}.{name}_avg` and `{param_prefix}.{name}_inv_std`.
fn normalize_extra_input(
    g: &mut Graph,
    value: NodeId,
    width: usize,
    param_prefix: &str,
    name: &str,
) -> NodeId {
    let avg = g.param(
        format!("{param_prefix}.{name}_avg"),
        Shape::new(&[width], DType::F32),
    );
    let inv_std = g.param(
        format!("{param_prefix}.{name}_inv_std"),
        Shape::new(&[width], DType::F32),
    );
    let centered = g.sub(value, avg);
    g.mul(centered, inv_std)
}
/// One fitting network: a hidden MLP (params under `{prefix}.hidden`) followed
/// by a linear, biased output layer (`{prefix}.final`) producing `out_dim` values.
fn fitting_mlp_path(
    g: &mut Graph,
    prefix: &str,
    in_dim: usize,
    neuron: &[usize],
    out_dim: usize,
    activation: ActivationKind,
    resnet_dt: bool,
    x: NodeId,
) -> NodeId {
    let hidden_prefix = format!("{prefix}.hidden");
    let hidden = embedding_mlp(
        g,
        &MlpSpec {
            param_prefix: &hidden_prefix,
            in_dim,
            neuron,
            activation,
            resnet_dt,
        },
        x,
    );

    // Width fed to the output layer: the last hidden width, or `in_dim` when
    // `neuron` is empty (assumes `embedding_mlp` then leaves the width unchanged).
    let hidden_width = neuron.last().copied().unwrap_or(in_dim);
    let output_prefix = format!("{prefix}.final");
    let output_layer = DenseLayerSpec {
        param_prefix: &output_prefix,
        in_dim: hidden_width,
        out_dim,
        activation: ActivationKind::Linear,
        use_bias: true,
        resnet_dt: false,
        resnet: false,
    };
    crate::nn::dense_layer(g, &output_layer, hidden)
}
/// Concatenates `a` and `b` along axis 2. The inputs are expected to be
/// `[nf, nloc, dim_a]` and `[nf, nloc, dim_b]`; the result is declared as an
/// F32 `[nf, nloc, dim_a + dim_b]` node.
fn concat_last(
    g: &mut Graph,
    a: NodeId,
    b: NodeId,
    nf: usize,
    nloc: usize,
    dim_a: usize,
    dim_b: usize,
) -> NodeId {
    let joined_width = dim_a + dim_b;
    g.concat(
        vec![a, b],
        2,
        Shape::new(&[nf, nloc, joined_width], DType::F32),
    )
}
/// Builds a `[nf, nloc]` F32 mask that is 1 where `atype == value` and 0 elsewhere.
///
/// The mask is computed as `relu(1 - (atype - value)^2)`. For integer-valued
/// types the square is 0 on a match and at least 1 otherwise, so the ReLU
/// clamps every non-match to 0.
fn atype_mask(
    g: &mut Graph,
    atype: NodeId,
    value: i32,
    nf: usize,
    nloc: usize,
) -> NodeId {
    // NOTE(review): `atype` is created as DType::I32 in the tests, but this
    // reshape declares an F32 result. Unless reshape converts element types,
    // the integer bits would be read as floats — confirm, or insert an
    // explicit cast here.
    let atype_f = g.reshape(
        atype,
        vec![nf as i64, nloc as i64],
        Shape::new(&[nf, nloc], DType::F32),
    );
    // A broadcast scalar (like `one` below) avoids materializing an
    // `nf * nloc`-element constant filled with `value` for every type.
    let target = scalar_const(g, value as _);
    let diff = g.sub(atype_f, target);
    let sq = g.mul(diff, diff);
    let one = scalar_const(g, 1.0);
    let one_minus_sq = g.sub(one, sq);
    let mask_shape = g.shape(one_minus_sq).clone();
    g.activation(rlx_ir::op::Activation::Relu, one_minus_sq, mask_shape)
}
/// Adds a per-type bias to `x`.
///
/// The `[ntypes, out_dim]` F32 parameter `param_name` is gathered along axis 0
/// by `atype`, and the gathered rows are added to `x`.
fn add_type_bias(
    g: &mut Graph,
    x: NodeId,
    atype: NodeId,
    ntypes: usize,
    out_dim: usize,
    param_name: &str,
) -> NodeId {
    let table_shape = Shape::new(&[ntypes, out_dim], DType::F32);
    let per_type_bias = g.param(param_name, table_shape);
    let per_atom_bias = g.gather_(per_type_bias, atype, 0);
    g.add(x, per_atom_bias)
}
/// Multiplies `x` by an optional mask and returns `x` unchanged when `mask` is `None`.
///
/// The mask is reshaped (declared F32) with trailing singleton axes so that its
/// rank matches `x`. At least one axis is always appended. A `[nf, nloc]` mask
/// therefore becomes:
/// - `[nf, nloc, 1]` for `[nf, nloc, k]` outputs;
/// - `[nf, nloc, 1, 1]` for `[nf, nloc, 3, 3]` outputs.
///
/// Non-static mask dimensions are passed to the reshape as `-1`.
fn apply_exclude_mask(g: &mut Graph, x: NodeId, mask: Option<NodeId>) -> NodeId {
    let Some(m) = mask else { return x };
    let mask_dims = g.shape(m).dims().to_vec();
    let x_rank = g.shape(x).dims().len();
    // With right-aligned broadcasting, appending a single axis to a rank-2 mask
    // would line `[nf, nloc, 1]` up against the trailing `[nloc, 3, 3]` of a
    // rank-4 output. Pad all the way up to the output rank instead.
    let n_trailing = x_rank.saturating_sub(mask_dims.len()).max(1);

    let target: Vec<i64> = mask_dims
        .iter()
        .map(|d| match d {
            rlx_ir::Dim::Static(n) => *n as i64,
            _ => -1,
        })
        .chain((0..n_trailing).map(|_| 1i64))
        .collect();
    let mut padded_dims = mask_dims;
    padded_dims.extend((0..n_trailing).map(|_| rlx_ir::Dim::Static(1)));
    let padded_shape = Shape::from_dims(&padded_dims, DType::F32);

    let m_padded = g.reshape(m, target, padded_shape);
    g.mul(x, m_padded)
}
/// Result of [`build_invar_fitting`].
pub struct InvarFitting {
    /// Per-atom output after the per-type bias and the optional exclusion mask.
    pub atom_output: NodeId,
}
/// Builds the generic per-atom fitting head.
///
/// The head consists of the trunk MLP (`cfg.dim_out` outputs, params under
/// `param_prefix`), a per-type bias table named `bias_param_name`, and the
/// optional exclusion mask from `inputs`.
///
/// # Errors
/// Fails if `cfg.activation_function` is `"linear"` (ASCII case-insensitive),
/// and propagates any error from building the trunk.
pub fn build_invar_fitting(
    g: &mut Graph,
    cfg: &EnerFittingConfig,
    inputs: FittingInputs,
    nf: usize,
    nloc: usize,
    param_prefix: &str,
    bias_param_name: &str,
) -> Result<InvarFitting> {
    // Case-insensitive comparison without allocating a lowercased copy.
    if cfg.activation_function.eq_ignore_ascii_case("linear") {
        bail!("invar fitting: hidden layers must use a non-linear activation");
    }
    let trunk = build_fitting_trunk(g, cfg, &inputs, nf, nloc, cfg.dim_out, param_prefix)?;
    let biased = add_type_bias(
        g,
        trunk,
        inputs.atype_loc,
        cfg.ntypes,
        cfg.dim_out,
        bias_param_name,
    );
    let atom_output = apply_exclude_mask(g, biased, inputs.exclude_mask);
    Ok(InvarFitting { atom_output })
}
/// Result of [`build_ener_fitting`].
pub struct EnerFitting {
    /// Per-atom energy node (the invariant fitting's `atom_output`).
    pub atom_energy: NodeId,
}
/// Builds the energy fitting network.
///
/// This is an invariant fitting whose trunk params live under `"fitting"` and
/// whose per-type bias is `"fitting.bias_atom_e"`. It uses no exclusion mask
/// and no `gr` input.
///
/// # Errors
/// Propagates any error from [`build_invar_fitting`].
pub fn build_ener_fitting(
    g: &mut Graph,
    cfg: &EnerFittingConfig,
    nf: usize,
    nloc: usize,
    descriptor: NodeId,
    atype: NodeId,
    fparam: Option<NodeId>,
    aparam: Option<NodeId>,
    case_embd: Option<NodeId>,
) -> Result<EnerFitting> {
    build_invar_fitting(
        g,
        cfg,
        FittingInputs {
            descriptor,
            atype_loc: atype,
            fparam,
            aparam,
            case_embd,
            exclude_mask: None,
            gr: None,
        },
        nf,
        nloc,
        "fitting",
        "fitting.bias_atom_e",
    )
    .map(|invar| EnerFitting {
        atom_energy: invar.atom_output,
    })
}
/// Result of [`build_dipole_fitting`].
pub struct DipoleFitting {
    /// Per-atom 3-vector, reshaped to `[nf, nloc, 3]` by the builder.
    pub dipole: NodeId,
}
/// Configuration for [`build_dipole_fitting`].
pub struct DipoleFittingConfig {
    /// Shared fitting settings; the builder overrides `dim_out` with `embedding_width`.
    pub base: EnerFittingConfig,
    /// Trunk output width `mw`, contracted against the `mw` axis of the `gr` input.
    pub embedding_width: usize,
}
pub fn build_dipole_fitting(
g: &mut Graph,
cfg: &DipoleFittingConfig,
inputs: FittingInputs,
nf: usize,
nloc: usize,
) -> Result<DipoleFitting> {
let gr = inputs.gr.ok_or_else(|| {
anyhow::anyhow!("dipole fitting requires the gr (equivariant rep) input")
})?;
let mw = cfg.embedding_width;
let mut cfg_inner = cfg.base.clone();
cfg_inner.dim_out = mw;
let trunk = build_fitting_trunk(g, &cfg_inner, &inputs, nf, nloc, mw, "fitting.dipole")?;
let trunk_4d_shape = Shape::new(&[nf, nloc, 1, mw], DType::F32);
let trunk_4d = g.reshape(
trunk,
vec![nf as i64, nloc as i64, 1, mw as i64],
trunk_4d_shape,
);
let prod = g.mm(trunk_4d, gr); let out_shape = Shape::new(&[nf, nloc, 3], DType::F32);
let mut out = g.reshape(prod, vec![nf as i64, nloc as i64, 3], out_shape);
out = apply_exclude_mask(g, out, inputs.exclude_mask);
Ok(DipoleFitting { dipole: out })
}
/// Result of [`build_polar_fitting`].
pub struct PolarFitting {
    /// Per-atom 3x3 tensor `gr^T · inner`, plus the optional diagonal shift
    /// (presumably `[nf, nloc, 3, 3]`, matching the identity used for the shift).
    pub polar: NodeId,
}
/// Configuration for [`build_polar_fitting`].
pub struct PolarFittingConfig {
    /// Shared fitting settings; the builder overrides `dim_out` with the trunk width.
    pub base: EnerFittingConfig,
    /// Width `mw` of the `gr` input's representation axis.
    pub embedding_width: usize,
    /// When true the trunk predicts `mw` per-row scales; otherwise it predicts a
    /// full `mw x mw` matrix that is symmetrized before use.
    pub fit_diag: bool,
    /// When true, add a per-type constant (times the per-type scale) to the 3x3 diagonal.
    pub shift_diag: bool,
}
pub fn build_polar_fitting(
g: &mut Graph,
cfg: &PolarFittingConfig,
inputs: FittingInputs,
nf: usize,
nloc: usize,
) -> Result<PolarFitting> {
let gr = inputs.gr.ok_or_else(|| {
anyhow::anyhow!("polar fitting requires the gr (equivariant rep) input")
})?;
let mw = cfg.embedding_width;
let net_out_dim = if cfg.fit_diag { mw } else { mw * mw };
let mut cfg_inner = cfg.base.clone();
cfg_inner.dim_out = net_out_dim;
let trunk =
build_fitting_trunk(g, &cfg_inner, &inputs, nf, nloc, net_out_dim, "fitting.polar")?;
let scale_table = g.param(
"fitting.polar.scale",
Shape::new(&[cfg.base.ntypes, 1], DType::F32),
);
let scale = g.gather_(scale_table, inputs.atype_loc, 0); let trunk_scaled = g.mul(trunk, scale);
let inner = if cfg.fit_diag {
let trunk_4d_shape = Shape::new(&[nf, nloc, mw, 1], DType::F32);
let t4 = g.reshape(
trunk_scaled,
vec![nf as i64, nloc as i64, mw as i64, 1],
trunk_4d_shape,
);
g.mul(t4, gr)
} else {
let mat_shape = Shape::new(&[nf, nloc, mw, mw], DType::F32);
let mat = g.reshape(
trunk_scaled,
vec![nf as i64, nloc as i64, mw as i64, mw as i64],
mat_shape,
);
let mat_t = g.transpose_(mat, vec![0, 1, 3, 2]);
let sym = g.add(mat, mat_t);
let half = scalar_const(g, 0.5);
let sym = g.mul(sym, half);
g.mm(sym, gr) };
let gr_t = g.transpose_(gr, vec![0, 1, 3, 2]); let mut out = g.mm(gr_t, inner);
if cfg.shift_diag {
let const_mat = g.param(
"fitting.polar.constant_matrix",
Shape::new(&[cfg.base.ntypes, 1], DType::F32),
);
let cm = g.gather_(const_mat, inputs.atype_loc, 0); let bias_scalar = g.mul(cm, scale);
let eye = identity_3x3(g, nf, nloc);
let bias_shape = Shape::new(&[nf, nloc, 1, 1], DType::F32);
let bias_scalar_4d = g.reshape(
bias_scalar,
vec![nf as i64, nloc as i64, 1, 1],
bias_shape,
);
let bias = g.mul(eye, bias_scalar_4d);
out = g.add(out, bias);
}
out = apply_exclude_mask(g, out, inputs.exclude_mask);
Ok(PolarFitting { polar: out })
}
/// Adds a constant F32 `[nf, nloc, 3, 3]` node holding one 3x3 identity matrix per atom.
fn identity_3x3(g: &mut Graph, nf: usize, nloc: usize) -> NodeId {
    // One row-major 3x3 identity, encoded as little-endian f32 bytes.
    let eye_bytes: Vec<u8> = (0..9)
        .flat_map(|k| {
            let (row, col) = (k / 3, k % 3);
            let v: f32 = if row == col { 1.0 } else { 0.0 };
            v.to_le_bytes()
        })
        .collect();
    // Tile it once for every (frame, atom) pair.
    let bytes = eye_bytes.repeat(nf * nloc);
    g.add_node(
        rlx_ir::op::Op::Constant { data: bytes },
        vec![],
        Shape::new(&[nf, nloc, 3, 3], DType::F32),
    )
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::EnerFittingConfig;

    /// Mixed-type energy fitting config with two 16-wide tanh hidden layers.
    fn ener_cfg(ntypes: usize, dd: usize) -> EnerFittingConfig {
        EnerFittingConfig {
            ntypes,
            dim_descrpt: dd,
            dim_out: 1,
            neuron: vec![16, 16],
            resnet_dt: true,
            numb_fparam: 0,
            numb_aparam: 0,
            dim_case_embd: 0,
            activation_function: "tanh".into(),
            mixed_types: true,
        }
    }

    /// Inputs with only the descriptor, atom types and (optionally) `gr` wired up.
    fn bare_inputs(descriptor: NodeId, atype_loc: NodeId, gr: Option<NodeId>) -> FittingInputs {
        FittingInputs {
            descriptor,
            atype_loc,
            fparam: None,
            aparam: None,
            case_embd: None,
            exclude_mask: None,
            gr,
        }
    }

    #[test]
    fn invar_per_type_builds() {
        let cfg = EnerFittingConfig {
            mixed_types: false,
            ..ener_cfg(3, 16)
        };
        let mut g = Graph::new("invar_pt");
        let (nf, nloc) = (1, 4);
        let descriptor = g.input("d", Shape::new(&[nf, nloc, cfg.dim_descrpt], DType::F32));
        let atype = g.input("atype", Shape::new(&[nf, nloc], DType::I32));
        let inputs = bare_inputs(descriptor, atype, None);
        let _ = build_invar_fitting(&mut g, &cfg, inputs, nf, nloc, "fit", "fit.bias")
            .expect("build");
        assert!(g.len() > 20);
    }

    #[test]
    fn dipole_builds() {
        let cfg = DipoleFittingConfig {
            base: ener_cfg(2, 24),
            embedding_width: 8,
        };
        let mut g = Graph::new("dipole");
        let (nf, nloc) = (1, 4);
        let descriptor = g.input("d", Shape::new(&[nf, nloc, cfg.base.dim_descrpt], DType::F32));
        let atype = g.input("atype", Shape::new(&[nf, nloc], DType::I32));
        let gr = g.input("gr", Shape::new(&[nf, nloc, cfg.embedding_width, 3], DType::F32));
        let inputs = bare_inputs(descriptor, atype, Some(gr));
        let _ = build_dipole_fitting(&mut g, &cfg, inputs, nf, nloc).expect("build");
        assert!(g.len() > 15);
    }

    #[test]
    fn polar_builds() {
        let cfg = PolarFittingConfig {
            base: ener_cfg(2, 24),
            embedding_width: 8,
            fit_diag: false,
            shift_diag: true,
        };
        let mut g = Graph::new("polar");
        let (nf, nloc) = (1, 4);
        let descriptor = g.input("d", Shape::new(&[nf, nloc, cfg.base.dim_descrpt], DType::F32));
        let atype = g.input("atype", Shape::new(&[nf, nloc], DType::I32));
        let gr = g.input("gr", Shape::new(&[nf, nloc, cfg.embedding_width, 3], DType::F32));
        let inputs = bare_inputs(descriptor, atype, Some(gr));
        let _ = build_polar_fitting(&mut g, &cfg, inputs, nf, nloc).expect("build");
        assert!(g.len() > 25);
    }
}