use anyhow::Result;
use rlx_ir::infer::GraphExt;
use rlx_ir::op::{BinaryOp, ReduceOp};
use rlx_ir::{DType, Graph, NodeId, Shape};
use crate::nn::scalar_const;
/// Configuration of the tabulated pair potential built by `build_pair_tab_energy`.
#[derive(Debug, Clone)]
pub struct PairTabConfig {
/// Number of atom types; the coefficient table holds `ntypes * ntypes` type pairs.
pub ntypes: usize,
/// Number of spline intervals tabulated for every type pair.
pub nspline: usize,
/// Start of the table: subtracted from the pair distance before scaling by `1 / hh`.
pub rmin: f32,
/// Spacing between consecutive table points (the distance is scaled by `1 / hh`).
pub hh: f32,
/// Cutoff radius; pair energies with distance `>= rcut` are masked to zero.
pub rcut: f32,
}
/// Graph nodes consumed by `build_pair_tab_energy`.
pub struct PairTabInputs {
/// Pair distances; must hold `nf * nloc * nnei` elements laid out as `[nf, nloc, nnei]`
/// (F32 in the unit test).
pub pairwise_rr: NodeId,
/// I32 types of the local (centre) atoms; reshaped to `[nf, nloc, 1]`, so it must hold
/// `nf * nloc` elements.
pub atype_loc: NodeId,
/// I32 type of every neighbor, combined element-wise with a `[nf, nloc, nnei]` tensor.
pub nei_atype: NodeId,
/// Neighbor mask multiplied into the per-pair energy; zero entries drop the pair.
pub nlist_mask: NodeId,
}
pub fn build_pair_tab_energy(
g: &mut Graph,
cfg: &PairTabConfig,
inputs: PairTabInputs,
nf: usize,
nloc: usize,
nnei: usize,
) -> Result<NodeId> {
let ntypes = cfg.ntypes;
let nspline = cfg.nspline;
let rmin = scalar_const(g, cfg.rmin);
let hi = scalar_const(g, 1.0 / cfg.hh);
let nspline_minus_one = scalar_const(g, (nspline - 1) as f32);
let shifted = g.sub(inputs.pairwise_rr, rmin);
let uu_total = g.mul(shifted, hi); let zero = scalar_const(g, 0.0);
let uu_clamped =
g.binary(BinaryOp::Max, uu_total, zero, g.shape(uu_total).clone());
let uu_clamped = g.binary(
BinaryOp::Min,
uu_clamped,
nspline_minus_one,
g.shape(uu_clamped).clone(),
);
let idx_i32 = g.cast(uu_clamped, DType::I32);
let idx_f = g.cast(idx_i32, DType::F32);
let uu = g.sub(uu_clamped, idx_f);
let ntypes_c = scalar_i32(g, ntypes as i32);
let nspline_c = scalar_i32(g, nspline as i32);
let i_type_4d = g.reshape(
inputs.atype_loc,
vec![nf as i64, nloc as i64, 1],
Shape::new(&[nf, nloc, 1], DType::I32),
);
let zero_i32_n = i32_zero_tensor(g, &[nf, nloc, nnei]);
let i_type_3d = g.binary(
BinaryOp::Add,
i_type_4d,
zero_i32_n,
Shape::new(&[nf, nloc, nnei], DType::I32),
);
let i_scaled = g.binary(
BinaryOp::Mul,
i_type_3d,
ntypes_c,
g.shape(i_type_3d).clone(),
);
let pair_id = g.binary(
BinaryOp::Add,
i_scaled,
inputs.nei_atype,
g.shape(i_scaled).clone(),
);
let pair_id_scaled = g.binary(
BinaryOp::Mul,
pair_id,
nspline_c,
g.shape(pair_id).clone(),
);
let flat_idx = g.binary(
BinaryOp::Add,
pair_id_scaled,
idx_i32,
g.shape(pair_id_scaled).clone(),
);
let coef = g.param(
"pairtab.coef",
Shape::new(&[ntypes, ntypes, nspline, 4], DType::F32),
);
let coef_flat_shape = Shape::new(&[ntypes * ntypes * nspline, 4], DType::F32);
let coef_flat = g.reshape(
coef,
vec![(ntypes * ntypes * nspline) as i64, 4],
coef_flat_shape,
);
let total = nf * nloc * nnei;
let flat_idx_1d = g.reshape(
flat_idx,
vec![total as i64],
Shape::new(&[total], DType::I32),
);
let gathered = g.gather_(coef_flat, flat_idx_1d, 0); let coef_per_pair = g.reshape(
gathered,
vec![nf as i64, nloc as i64, nnei as i64, 4],
Shape::new(&[nf, nloc, nnei, 4], DType::F32),
);
let a0 = g.narrow_(coef_per_pair, 3, 0, 1);
let a1 = g.narrow_(coef_per_pair, 3, 1, 1);
let a2 = g.narrow_(coef_per_pair, 3, 2, 1);
let a3 = g.narrow_(coef_per_pair, 3, 3, 1);
let squeeze = |g: &mut Graph, x: NodeId| {
g.reshape(
x,
vec![nf as i64, nloc as i64, nnei as i64],
Shape::new(&[nf, nloc, nnei], DType::F32),
)
};
let a0 = squeeze(g, a0);
let a1 = squeeze(g, a1);
let a2 = squeeze(g, a2);
let a3 = squeeze(g, a3);
let u2 = g.mul(uu, uu);
let u3 = g.mul(u2, uu);
let t1 = g.mul(a1, uu);
let t2 = g.mul(a2, u2);
let t3 = g.mul(a3, u3);
let mut ener = g.add(a0, t1);
ener = g.add(ener, t2);
ener = g.add(ener, t3);
let rcut_c = scalar_const(g, cfg.rcut);
let _ = rcut_c;
let inside_mask = step_lt(g, inputs.pairwise_rr, cfg.rcut, nf, nloc, nnei);
let ener = g.mul(ener, inside_mask);
let ener = g.mul(ener, inputs.nlist_mask);
let sum_shape = Shape::new(&[nf, nloc], DType::F32);
let summed = g.reduce(ener, ReduceOp::Sum, vec![2], false, sum_shape);
let half = scalar_const(g, 0.5);
let halved = g.mul(summed, half);
let out_shape = Shape::new(&[nf, nloc, 1], DType::F32);
let out = g.reshape(halved, vec![nf as i64, nloc as i64, 1], out_shape);
Ok(out)
}
/// Builds the smooth switching weight used to blend ZBL and DP energies.
///
/// `sigma` is normalised to `x = clamp((sigma - sw_rmin) / (sw_rmax - sw_rmin), 0, 1)`.
/// It is then passed through the quintic `1 - 10x^3 + 15x^4 - 6x^5`. That gives
/// 1 for `sigma <= sw_rmin` and 0 for `sigma >= sw_rmax`, with a smooth
/// transition in between.
/// The weight is reshaped to `[nf, nloc, 1]`, so `sigma` must hold `nf * nloc`
/// elements.
pub fn build_zbl_weight(
    g: &mut Graph,
    sigma: NodeId,
    sw_rmin: f32,
    sw_rmax: f32,
    nf: usize,
    nloc: usize,
) -> NodeId {
    let lo = scalar_const(g, sw_rmin);
    let inv_width = scalar_const(g, 1.0 / (sw_rmax - sw_rmin));
    let one = scalar_const(g, 1.0);
    let zero = scalar_const(g, 0.0);

    // Position inside the switching window, clamped to [0, 1].
    let offset = g.sub(sigma, lo);
    let raw = g.mul(offset, inv_width);
    let floored = {
        let shape = g.shape(raw).clone();
        g.binary(BinaryOp::Max, raw, zero, shape)
    };
    let x = {
        let shape = g.shape(floored).clone();
        g.binary(BinaryOp::Min, floored, one, shape)
    };

    // Successive powers of x; x^2 is only an intermediate.
    let x2 = g.mul(x, x);
    let x3 = g.mul(x2, x);
    let x4 = g.mul(x3, x);
    let x5 = g.mul(x4, x);

    let (k10, k15, k6) = (
        scalar_const(g, 10.0),
        scalar_const(g, 15.0),
        scalar_const(g, 6.0),
    );
    let cubic = g.mul(k10, x3);
    let quartic = g.mul(k15, x4);
    let quintic = g.mul(k6, x5);

    // 1 - 10x^3 + 15x^4 - 6x^5, accumulated left to right.
    let acc = g.sub(one, cubic);
    let acc = g.add(acc, quartic);
    let weight = g.sub(acc, quintic);

    g.reshape(
        weight,
        vec![nf as i64, nloc as i64, 1],
        Shape::new(&[nf, nloc, 1], DType::F32),
    )
}
/// Blends two energies with weight `c`: returns `(1 - c) * e_dp + c * e_zbl`.
pub fn combine_dp_zbl(g: &mut Graph, e_dp: NodeId, e_zbl: NodeId, c: NodeId) -> NodeId {
    let unit = scalar_const(g, 1.0);
    let dp_weight = g.sub(unit, c);
    let dp_part = g.mul(dp_weight, e_dp);
    let zbl_part = g.mul(c, e_zbl);
    g.add(dp_part, zbl_part)
}
/// Adds a one-element I32 constant node holding `v` (little-endian bytes).
fn scalar_i32(g: &mut Graph, v: i32) -> NodeId {
    let data = Vec::from(v.to_le_bytes());
    let shape = Shape::new(&[1], DType::I32);
    g.add_node(rlx_ir::op::Op::Constant { data }, Vec::new(), shape)
}
/// Adds an all-zero I32 constant node with the given `shape`.
///
/// The constant stores four zero bytes per element (`product(shape)` elements;
/// an empty `shape` yields a single element).
fn i32_zero_tensor(g: &mut Graph, shape: &[usize]) -> NodeId {
    let elems = shape.iter().copied().product::<usize>();
    let data = vec![0u8; elems * std::mem::size_of::<i32>()];
    g.add_node(
        rlx_ir::op::Op::Constant { data },
        Vec::new(),
        Shape::new(shape, DType::I32),
    )
}
/// Builds a mask that is 1.0 where `x < threshold` and 0.0 where `x >= threshold`.
///
/// It is computed as `min(max(threshold - x, 0) * 1e30, 1)`. Any positive margin
/// below the threshold is amplified past 1 and capped, while margins smaller
/// than about `1e-30` give a fractional value. `nf`, `nloc` and `nnei` are
/// unused; all shapes follow `x`.
fn step_lt(
    g: &mut Graph,
    x: NodeId,
    threshold: f32,
    nf: usize,
    nloc: usize,
    nnei: usize,
) -> NodeId {
    let _ = (nf, nloc, nnei);
    let thr = scalar_const(g, threshold);
    let overshoot = g.sub(x, thr);
    // threshold - x, expressed as the negated overshoot.
    let margin = g.neg(overshoot);
    let zero = scalar_const(g, 0.0);
    let positive_margin = {
        let shape = g.shape(margin).clone();
        g.binary(BinaryOp::Max, margin, zero, shape)
    };
    let huge = scalar_const(g, 1e30);
    let amplified = g.mul(positive_margin, huge);
    let one = scalar_const(g, 1.0);
    let shape = g.shape(amplified).clone();
    g.binary(BinaryOp::Min, amplified, one, shape)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The switching weight must come out as `[nf, nloc, 1]`.
    #[test]
    fn zbl_weight_builds() {
        let mut g = Graph::new("zbl_weight");
        let sigma = g.input("sigma", Shape::new(&[1, 4], DType::F32));
        let w = build_zbl_weight(&mut g, sigma, 0.5, 1.5, 1, 4);
        let s = g.shape(w);
        assert_eq!(s.dim(0), rlx_ir::Dim::Static(1));
        assert_eq!(s.dim(1), rlx_ir::Dim::Static(4));
        assert_eq!(s.dim(2), rlx_ir::Dim::Static(1));
    }

    /// The per-atom pair-table energy must come out as `[nf, nloc, 1]`.
    #[test]
    fn pair_tab_energy_builds() {
        let mut g = Graph::new("pair_tab");
        let nf = 1;
        let nloc = 2;
        let nnei = 8;
        let cfg = PairTabConfig {
            ntypes: 2,
            nspline: 16,
            rmin: 0.0,
            hh: 0.5,
            rcut: 6.0,
        };
        let inputs = PairTabInputs {
            pairwise_rr: g.input("rr", Shape::new(&[nf, nloc, nnei], DType::F32)),
            atype_loc: g.input("atype", Shape::new(&[nf, nloc], DType::I32)),
            nei_atype: g.input("nei_atype", Shape::new(&[nf, nloc, nnei], DType::I32)),
            nlist_mask: g.input("nlist_mask", Shape::new(&[nf, nloc, nnei], DType::F32)),
        };
        let e = build_pair_tab_energy(&mut g, &cfg, inputs, nf, nloc, nnei).expect("build");
        let s = g.shape(e);
        assert_eq!(s.dim(0), rlx_ir::Dim::Static(1));
        assert_eq!(s.dim(1), rlx_ir::Dim::Static(2));
        assert_eq!(s.dim(2), rlx_ir::Dim::Static(1));
    }
}