use integral_math::solid_harmonics::{c2s_matrix, monomial_to_raw_factor};
use crate::shell::{Shell, ShellKind};
/// Builds the Cartesian-to-spherical transformation matrix for a shell.
///
/// Returns `None` for a Cartesian shell, meaning no transformation is applied
/// to that shell's axis. For a spherical shell of angular momentum `l`, returns
/// `c2s_matrix(l)` with every entry multiplied by `monomial_to_raw_factor(l)`.
///
/// NOTE(review): the result is assumed to be a row-major `n_spherical x
/// n_cartesian` matrix (for `l = 2` its length is `5 * 6`), which is how
/// `transform_block` consumes it — confirm against `c2s_matrix`.
pub(crate) fn shell_transform(s: &Shell) -> Option<Vec<f64>> {
    let l = match s.kind() {
        ShellKind::Cartesian => return None,
        ShellKind::Spherical => s.l(),
    };
    let scale = monomial_to_raw_factor(l);
    let scaled: Vec<f64> = c2s_matrix(l).iter().map(|&entry| scale * entry).collect();
    Some(scaled)
}
/// Contracts one axis of a row-major tensor with a dense row-major matrix.
///
/// `tensor` is read as a row-major array of shape `dims`. With
/// `n_in = dims[axis]`, `mat` is an `n_out x n_in` row-major matrix, and each
/// output slice along `axis` is
/// `out[.., q, ..] = sum_j mat[q * n_in + j] * tensor[.., j, ..]`.
/// All other axes are left unchanged.
///
/// Returns the new data together with its shape, which is `dims` with
/// `dims[axis]` replaced by `n_out`.
///
/// Zero entries of `mat` are skipped entirely. As a result, a non-finite value
/// in `tensor` that meets only zero coefficients does not propagate into the
/// output.
///
/// # Panics
///
/// Panics if `axis >= dims.len()`. Also panics if `tensor` or `mat` is shorter
/// than the shape implies (when the output is non-empty). In debug builds, it
/// additionally panics if `mat.len() != n_out * dims[axis]`.
fn apply_axis(
    tensor: &[f64],
    dims: &[usize],
    axis: usize,
    mat: &[f64],
    n_out: usize,
) -> (Vec<f64>, Vec<usize>) {
    let n_in = dims[axis];
    debug_assert_eq!(mat.len(), n_out * n_in);
    // View the tensor as `[outer][n_in][inner]`: `outer` folds every axis before
    // `axis` and `inner` every axis after it (both are 1 when there are none).
    let outer: usize = dims[..axis].iter().product();
    let inner: usize = dims[axis + 1..].iter().product();
    let mut new_dims = dims.to_vec();
    new_dims[axis] = n_out;
    let mut out = vec![0.0_f64; outer * n_out * inner];
    // Either the output is empty, or every sum has zero terms. In both cases the
    // zero-filled buffer is already the answer. Returning here also guarantees
    // every chunk size below is non-zero, which `chunks_exact` requires.
    if out.is_empty() || n_in == 0 {
        return (out, new_dims);
    }
    // Slice to the exact extents once. A too-short input still panics, as
    // per-element indexing would, and the loops below need no per-element
    // bounds checks.
    let tensor = &tensor[..outer * n_in * inner];
    let mat = &mat[..n_out * n_in];
    for (src_slab, dst_slab) in tensor
        .chunks_exact(n_in * inner)
        .zip(out.chunks_exact_mut(n_out * inner))
    {
        let dst_rows = dst_slab.chunks_exact_mut(inner);
        for (dst_row, mat_row) in dst_rows.zip(mat.chunks_exact(n_in)) {
            let src_rows = src_slab.chunks_exact(inner);
            for (&m, src_row) in mat_row.iter().zip(src_rows) {
                // Skip exact zeros; this pays off when `mat` is sparse.
                if m == 0.0 {
                    continue;
                }
                for (d, &x) in dst_row.iter_mut().zip(src_row) {
                    *d += m * x;
                }
            }
        }
    }
    (out, new_dims)
}
/// Applies optional per-axis transformation matrices to a row-major block.
///
/// `mats[axis]` is either `None`, which leaves that axis untouched, or a
/// row-major `n_out x dims[axis]` matrix. `n_out` is inferred as
/// `mat.len() / dims[axis]`. Axes are processed in increasing order, and the
/// data is contracted by `apply_axis` one axis at a time. Only the transformed
/// data is returned, so the caller must track the resulting shape.
///
/// If every entry of `mats` is `None`, `block` is returned as-is without
/// copying.
///
/// # Panics
///
/// Panics when a matrix is supplied for an axis of length zero (division by
/// zero while inferring `n_out`). Also panics when `mats` has a `Some` entry
/// past the end of `dims`. In debug builds, it additionally panics if
/// `dims.len() != mats.len()`.
pub(crate) fn transform_block(
    block: Vec<f64>,
    dims: &[usize],
    mats: &[Option<&[f64]>],
) -> Vec<f64> {
    debug_assert_eq!(dims.len(), mats.len());
    let transformed_axes = mats
        .iter()
        .copied()
        .enumerate()
        .filter_map(|(axis, mat)| Some((axis, mat?)));
    let (data, _shape) = transformed_axes.fold(
        (block, dims.to_vec()),
        |(data, shape), (axis, mat)| {
            // Only earlier axes have been resized so far, so `dims[axis]` is
            // still this axis's current length.
            let n_out = mat.len() / dims[axis];
            apply_axis(&data, &shape, axis, mat, n_out)
        },
    );
    data
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cartesian_shell_is_identity() {
        let s = Shell::new(2, [0.0, 0.0, 0.0], vec![1.0], vec![1.0]).unwrap();
        assert!(shell_transform(&s).is_none());
    }

    #[test]
    fn spherical_d_transform_shape() {
        let s = Shell::new_spherical(2, [0.0, 0.0, 0.0], vec![1.0], vec![1.0]).unwrap();
        let m = shell_transform(&s).unwrap();
        // For l = 2: 5 spherical components times 6 Cartesian components.
        assert_eq!(m.len(), 5 * 6);
    }

    #[test]
    fn apply_axis_identity_roundtrip() {
        let block = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let id = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0];
        let (out, dims) = apply_axis(&block, &[2, 3], 1, &id, 3);
        assert_eq!(dims, vec![2, 3]);
        assert_eq!(out, block);
    }

    #[test]
    fn apply_axis_changes_axis_length() {
        // A 2x3 matrix maps the length-3 last axis onto length 2.
        let block = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mat = [1.0, 1.0, 1.0, 1.0, 0.0, -1.0];
        let (out, dims) = apply_axis(&block, &[2, 3], 1, &mat, 2);
        assert_eq!(dims, vec![2, 2]);
        assert_eq!(out, vec![6.0, -2.0, 15.0, -2.0]);
    }

    #[test]
    fn apply_axis_contracts_middle_axis() {
        // Swapping the two entries of the middle axis of a 2x2x2 tensor.
        let block: Vec<f64> = (1..=8).map(f64::from).collect();
        let swap = [0.0, 1.0, 1.0, 0.0];
        let (out, dims) = apply_axis(&block, &[2, 2, 2], 1, &swap, 2);
        assert_eq!(dims, vec![2, 2, 2]);
        assert_eq!(out, vec![3.0, 4.0, 1.0, 2.0, 7.0, 8.0, 5.0, 6.0]);
    }

    #[test]
    fn transform_block_left_and_right_axes() {
        let block = vec![1.0, 2.0, 3.0, 4.0];
        // `sum_rows` is a 1x2 matrix that collapses axis 0. `id` leaves axis 1
        // unchanged.
        let sum_rows = [1.0, 1.0];
        let id = [1.0, 0.0, 0.0, 1.0];
        let out = transform_block(block, &[2, 2], &[Some(&sum_rows[..]), Some(&id[..])]);
        assert_eq!(out, vec![4.0, 6.0]);
    }
}