use anyhow::Result;
use rlx_ir::infer::GraphExt;
use rlx_ir::{DType, Graph, NodeId, Shape};
/// Configuration for the spin-expansion helpers in this module.
#[derive(Clone, Debug)]
pub struct SpinConfig {
    /// Number of real (non-virtual) atom types.
    ///
    /// `build_spin_inputs` shifts each real type id by this amount to obtain the
    /// matching virtual type id. Assuming real ids lie in `0..ntypes_real`, virtual
    /// ids therefore lie in `ntypes_real..2 * ntypes_real`.
    pub ntypes_real: usize,
}
/// Graph nodes produced by `build_spin_inputs`.
///
/// All three tensors cover `2 * nloc` atoms along axis 1: the real atoms first,
/// followed by one virtual atom per real atom.
pub struct SpinExpanded {
    /// Real coordinates concatenated with the spin-displaced virtual coordinates.
    pub coord_spin: NodeId,
    /// Real type ids concatenated with the shifted virtual type ids (`I32`).
    pub atype_spin: NodeId,
    /// Zeros for the real half, and the negated spin displacement for the virtual half.
    pub coord_corr: NodeId,
}
pub fn build_spin_inputs(
g: &mut Graph,
cfg: &SpinConfig,
coord: NodeId,
atype: NodeId,
spin: NodeId,
nf: usize,
nloc: usize,
) -> Result<SpinExpanded> {
let two_nt = 2 * cfg.ntypes_real;
let ntypes_const = i32_filled(g, cfg.ntypes_real as i32, &[nf, nloc]);
let virtual_atype = g.add(atype, ntypes_const);
let atype_spin_shape = Shape::new(&[nf, 2 * nloc], DType::I32);
let atype_spin = g.concat(vec![atype, virtual_atype], 1, atype_spin_shape);
let vsm = g.param(
"spin.virtual_scale_mask",
Shape::new(&[two_nt], DType::F32),
);
let scale_per_atom = g.gather_(vsm, atype, 0); let scale_3d_shape = Shape::new(&[nf, nloc, 1], DType::F32);
let scale_3d = g.reshape(
scale_per_atom,
vec![nf as i64, nloc as i64, 1],
scale_3d_shape,
);
let spin_dist = g.mul(spin, scale_3d); let virtual_coord = g.add(coord, spin_dist);
let coord_spin_shape = Shape::new(&[nf, 2 * nloc, 3], DType::F32);
let coord_spin = g.concat(vec![coord, virtual_coord], 1, coord_spin_shape);
let zero_real = zero_tensor(g, &[nf, nloc, 3]);
let neg_spin_dist = g.neg(spin_dist);
let coord_corr_shape = Shape::new(&[nf, 2 * nloc, 3], DType::F32);
let coord_corr = g.concat(vec![zero_real, neg_spin_dist], 1, coord_corr_shape);
Ok(SpinExpanded {
coord_spin,
atype_spin,
coord_corr,
})
}
/// Splits an expanded (real + virtual) output back into its two halves.
///
/// Both halves are narrowed views of `output` along axis 1, each `nloc` entries
/// long. The first holds the real atoms and the second starts at offset `nloc`.
/// The narrow arguments are presumably `(axis, start, length)`.
///
/// `nf` and `dim` are currently unused. They are accepted so that the signature
/// describes the full `output` shape.
///
/// Returns `(real_half, virtual_half)`.
pub fn split_spin_output(
    g: &mut Graph,
    output: NodeId,
    nf: usize,
    nloc: usize,
    dim: usize,
) -> (NodeId, NodeId) {
    let _unused_dims = (nf, dim);
    let real_half = g.narrow_(output, 1, 0, nloc);
    let virtual_half = g.narrow_(output, 1, nloc, nloc);
    (real_half, virtual_half)
}
/// Adds an `I32` constant node of the given `shape`, with every element set to `value`.
///
/// The payload is `value` encoded as 4 little-endian bytes, repeated once per element.
fn i32_filled(g: &mut Graph, value: i32, shape: &[usize]) -> NodeId {
    let element_count = shape.iter().product::<usize>();
    let payload = value.to_le_bytes().repeat(element_count);
    let out_shape = Shape::new(shape, DType::I32);
    g.add_node(
        rlx_ir::op::Op::Constant { data: payload },
        Vec::new(),
        out_shape,
    )
}
/// Adds an `F32` constant node of the given `shape`, filled with `0.0`.
///
/// An all-zero-bits buffer encodes `+0.0` for every `f32` element.
fn zero_tensor(g: &mut Graph, shape: &[usize]) -> NodeId {
    let byte_len = shape.iter().product::<usize>() * std::mem::size_of::<f32>();
    let out_shape = Shape::new(shape, DType::F32);
    g.add_node(
        rlx_ir::op::Op::Constant {
            data: vec![0u8; byte_len],
        },
        Vec::new(),
        out_shape,
    )
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The expanded coordinate, type and correction tensors all double the
    /// atom axis. The coordinate tensors keep the trailing xyz axis of length 3.
    #[test]
    fn spin_inputs_build() {
        let mut g = Graph::new("spin");
        let nf = 1;
        let nloc = 4;
        let cfg = SpinConfig { ntypes_real: 2 };
        let coord = g.input("coord", Shape::new(&[nf, nloc, 3], DType::F32));
        let atype = g.input("atype", Shape::new(&[nf, nloc], DType::I32));
        let spin = g.input("spin", Shape::new(&[nf, nloc, 3], DType::F32));
        let out = build_spin_inputs(&mut g, &cfg, coord, atype, spin, nf, nloc)
            .expect("build");

        let s_coord = g.shape(out.coord_spin);
        assert_eq!(s_coord.dim(1), rlx_ir::Dim::Static(2 * nloc));
        assert_eq!(s_coord.dim(2), rlx_ir::Dim::Static(3));

        let s_atype = g.shape(out.atype_spin);
        assert_eq!(s_atype.dim(1), rlx_ir::Dim::Static(2 * nloc));

        // The correction tensor must line up with the expanded coordinates.
        let s_corr = g.shape(out.coord_corr);
        assert_eq!(s_corr.dim(1), rlx_ir::Dim::Static(2 * nloc));
        assert_eq!(s_corr.dim(2), rlx_ir::Dim::Static(3));
    }

    /// Splitting an expanded output yields two halves of `nloc` atoms each.
    #[test]
    fn spin_output_splits() {
        let mut g = Graph::new("spin_split");
        let nf = 1;
        let nloc = 4;
        let backbone = g.input("out", Shape::new(&[nf, 2 * nloc, 1], DType::F32));
        let (real, virt) = split_spin_output(&mut g, backbone, nf, nloc, 1);
        assert_eq!(g.shape(real).dim(1), rlx_ir::Dim::Static(nloc));
        assert_eq!(g.shape(virt).dim(1), rlx_ir::Dim::Static(nloc));
    }
}