use anyhow::Result;
use rlx_ir::infer::GraphExt;
use rlx_ir::op::ReduceOp;
use rlx_ir::{DType, Graph, NodeId, Shape};
pub fn reduce_atomic_sum(g: &mut Graph, atomic: NodeId, atom_axis: usize) -> NodeId {
let in_shape = g.shape(atomic).clone();
let mut out_dims: Vec<rlx_ir::Dim> = in_shape.dims().to_vec();
out_dims.remove(atom_axis);
let out_shape = Shape::from_dims(&out_dims, in_shape.dtype());
g.reduce(atomic, ReduceOp::Sum, vec![atom_axis], false, out_shape)
}
pub fn reduce_atomic_mean(g: &mut Graph, atomic: NodeId, atom_axis: usize) -> NodeId {
let in_shape = g.shape(atomic).clone();
let mut out_dims: Vec<rlx_ir::Dim> = in_shape.dims().to_vec();
out_dims.remove(atom_axis);
let out_shape = Shape::from_dims(&out_dims, in_shape.dtype());
g.reduce(atomic, ReduceOp::Mean, vec![atom_axis], false, out_shape)
}
pub fn build_force_grad_graph(forward: &Graph, target: NodeId, wrt: &[NodeId]) -> Graph {
let _ = (target, wrt);
forward.clone()
}
/// Builds a virial-style tensor from per-atom forces and positions.
///
/// For every frame `f` the result is
/// `V[f, a, b] = -Σ_i forces[f, i, a] * positions[f, i, b]`,
/// i.e. the negated sum over atoms of the outer product of force and position.
///
/// * `forces`, `positions`: each is reshaped to four dimensions, so each must
///   hold `nf * nloc * 3` elements (for example `[nf, nloc, 3]`).
/// * `nf`: number of frames. `nloc`: number of atoms per frame.
///
/// Returns a node declared with shape `[nf, 3, 3]`. Declared dtypes follow the
/// input nodes instead of being fixed to `F32`.
pub fn build_virial_via_position(
    g: &mut Graph,
    forces: NodeId,
    positions: NodeId,
    nf: usize,
    nloc: usize,
) -> NodeId {
    // A reshape never changes the element type, so carry each input's dtype.
    let f_dtype = g.shape(forces).dtype();
    let r_dtype = g.shape(positions).dtype();

    // [nf, nloc, 3, 1] * [nf, nloc, 1, 3] forms the per-atom 3x3 outer product.
    // NOTE(review): relies on `mul` broadcasting the size-1 axes; confirm in rlx_ir.
    let f_4d_shape = Shape::new(&[nf, nloc, 3, 1], f_dtype);
    let f_4d = g.reshape(forces, vec![nf as i64, nloc as i64, 3, 1], f_4d_shape);
    let r_4d_shape = Shape::new(&[nf, nloc, 1, 3], r_dtype);
    let r_4d = g.reshape(positions, vec![nf as i64, nloc as i64, 1, 3], r_4d_shape);
    let outer = g.mul(f_4d, r_4d);

    // Sum over atoms (axis 1). The reduction keeps the product node's dtype.
    let sum_shape = Shape::new(&[nf, 3, 3], g.shape(outer).dtype());
    let sum = g.reduce(outer, ReduceOp::Sum, vec![1], false, sum_shape);
    g.neg(sum)
}
/// Sums per-atom energies into a single total-energy node.
///
/// `atomic_energy` is first summed over axis 1 via [`reduce_atomic_sum`]. That
/// result is then summed over axis 0, so energies from all frames are added
/// together. The output dtype matches the input's.
///
/// NOTE(review): the output shape is declared as `[1]`. This matches the
/// reduction only when `atomic_energy` is `[nf, nloc, 1]`; confirm against
/// callers. `nf` is currently unused.
///
/// # Errors
/// Never returns `Err` today. The `Result` return type is kept for API stability.
///
/// # Panics
/// Panics if `atomic_energy` has rank < 2, because axis 1 must exist.
pub fn build_total_energy(g: &mut Graph, atomic_energy: NodeId, nf: usize) -> Result<NodeId> {
    let _ = nf;
    let per_frame = reduce_atomic_sum(g, atomic_energy, 1);
    // Take the dtype from the per-frame sum, which preserves the input dtype,
    // rather than hard-coding F32. Otherwise non-F32 graphs get mislabeled.
    let scalar_shape = Shape::new(&[1], g.shape(per_frame).dtype());
    Ok(g.reduce(per_frame, ReduceOp::Sum, vec![0], false, scalar_shape))
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn reductions_match_axis() {
        let mut g = Graph::new("reduce");
        let x = g.input("x", Shape::new(&[2, 4, 3], DType::F32));
        let s = reduce_atomic_sum(&mut g, x, 1);
        let shape = g.shape(s);
        assert_eq!(shape.rank(), 2);
        assert_eq!(shape.dim(0), rlx_ir::Dim::Static(2));
        assert_eq!(shape.dim(1), rlx_ir::Dim::Static(3));
    }

    #[test]
    fn mean_reduction_drops_axis() {
        let mut g = Graph::new("mean");
        let x = g.input("x", Shape::new(&[2, 4, 3], DType::F32));
        let m = reduce_atomic_mean(&mut g, x, 1);
        let shape = g.shape(m);
        assert_eq!(shape.rank(), 2);
        assert_eq!(shape.dim(0), rlx_ir::Dim::Static(2));
        assert_eq!(shape.dim(1), rlx_ir::Dim::Static(3));
    }

    #[test]
    fn virial_shape_matches() {
        let mut g = Graph::new("virial");
        let nf = 1;
        let nloc = 4;
        let f = g.input("f", Shape::new(&[nf, nloc, 3], DType::F32));
        let r = g.input("r", Shape::new(&[nf, nloc, 3], DType::F32));
        let v = build_virial_via_position(&mut g, f, r, nf, nloc);
        let s = g.shape(v);
        assert_eq!(s.rank(), 3);
        assert_eq!(s.dim(0), rlx_ir::Dim::Static(nf));
        assert_eq!(s.dim(1), rlx_ir::Dim::Static(3));
        assert_eq!(s.dim(2), rlx_ir::Dim::Static(3));
    }

    #[test]
    fn total_energy_is_single_element() -> Result<()> {
        let mut g = Graph::new("energy");
        let e = g.input("e", Shape::new(&[2, 4, 1], DType::F32));
        let total = build_total_energy(&mut g, e, 2)?;
        let shape = g.shape(total);
        assert_eq!(shape.rank(), 1);
        assert_eq!(shape.dim(0), rlx_ir::Dim::Static(1));
        Ok(())
    }
}