gam 0.2.3

Generalized penalized likelihood engine
use ndarray::{ArrayView1, ArrayView2, ArrayView3, ArrayViewD, IxDyn};

/// Symmetry group selector used by the equivariant penalty code.
///
/// Each variant carries a canonical label (see [`EquivariantGroup::name`]) and
/// a representation dimension (see [`EquivariantGroup::rep_dim`]).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum EquivariantGroup {
    /// Labelled "SO2"; representation dimension 2.
    SO2,
    /// Labelled "SO3"; representation dimension 3.
    SO3,
    /// Labelled "R1"; representation dimension 1.
    R1,
    /// Labelled "TRIVIAL"; representation dimension 1.
    Trivial,
}

impl EquivariantGroup {
    /// Returns the canonical upper-case label of the group.
    fn name(self) -> &'static str {
        match self {
            Self::SO2 => "SO2",
            Self::SO3 => "SO3",
            Self::R1 => "R1",
            Self::Trivial => "TRIVIAL",
        }
    }

    /// Returns the dimension of the representation matrix for the group.
    fn rep_dim(self) -> usize {
        match self {
            Self::R1 | Self::Trivial => 1,
            Self::SO2 => 2,
            Self::SO3 => 3,
        }
    }

    /// Parses a group label.
    ///
    /// Matching is case-sensitive; the trivial group is accepted under both
    /// the "Trivial" and "TRIVIAL" spellings.
    ///
    /// # Errors
    /// Returns a message listing the accepted labels when `group` is not one
    /// of them.
    fn parse(group: &str) -> Result<Self, String> {
        // Accepted spellings paired with the variant they select.
        const KNOWN: [(&str, EquivariantGroup); 5] = [
            ("SO2", EquivariantGroup::SO2),
            ("SO3", EquivariantGroup::SO3),
            ("R1", EquivariantGroup::R1),
            ("Trivial", EquivariantGroup::Trivial),
            ("TRIVIAL", EquivariantGroup::Trivial),
        ];
        KNOWN
            .iter()
            .find(|(label, _)| *label == group)
            .map(|&(_, parsed)| parsed)
            .ok_or_else(|| {
                let other = group;
                format!("group must be 'SO2', 'SO3', 'R1', or 'Trivial'; got {other:?}")
            })
    }
}

/// Reads a single element from a dynamic-rank array view.
///
/// `label` is only used to prefix the error message.
///
/// # Errors
/// Returns `"{label}: index {index:?} out of bounds"` when the checked lookup
/// at `index` yields no element.
fn dynamic_value(
    values: &ArrayViewD<'_, f64>,
    index: &[usize],
    label: &str,
) -> Result<f64, String> {
    match values.get(IxDyn(index)) {
        Some(&entry) => Ok(entry),
        None => Err(format!("{label}: index {index:?} out of bounds")),
    }
}

/// Inverts a small dense square matrix by Gauss-Jordan elimination with
/// partial (row) pivoting on an augmented `[A | I]` system.
///
/// `context` prefixes every error message so callers can tell which matrix
/// failed.
///
/// # Errors
/// Returns an error when the matrix is empty, when a row's length differs
/// from the row count, when an entry is not finite, or when the largest
/// available pivot in some column has magnitude `<= 1e-12`.
fn invert_small_matrix(matrix: &[Vec<f64>], context: &str) -> Result<Vec<Vec<f64>>, String> {
    let n = matrix.len();
    if n == 0 {
        return Err(format!("{context}: matrix must not be empty"));
    }

    // Build [A | I] row by row, validating each row as it is copied.
    let width = 2 * n;
    let mut aug: Vec<Vec<f64>> = Vec::with_capacity(n);
    for (i, source_row) in matrix.iter().enumerate() {
        if source_row.len() != n {
            return Err(format!("{context}: matrix must be square"));
        }
        let mut augmented_row = vec![0.0_f64; width];
        for (j, &value) in source_row.iter().enumerate() {
            if !value.is_finite() {
                return Err(format!("{context}: matrix entry [{i},{j}] is not finite"));
            }
            augmented_row[j] = value;
        }
        augmented_row[n + i] = 1.0;
        aug.push(augmented_row);
    }

    for col in 0..n {
        // Choose the first row at or below the diagonal whose entry has the
        // strictly largest magnitude in this column.
        let (pivot, pivot_abs) = ((col + 1)..n).fold(
            (col, aug[col][col].abs()),
            |(best_row, best_abs), row| {
                let candidate = aug[row][col].abs();
                if candidate > best_abs {
                    (row, candidate)
                } else {
                    (best_row, best_abs)
                }
            },
        );
        if pivot_abs <= 1.0e-12 {
            return Err(format!("{context}: matrix is singular at pivot {col}"));
        }
        if pivot != col {
            aug.swap(pivot, col);
        }

        // Normalise the pivot row so the pivot entry becomes one.
        let scale = aug[col][col];
        aug[col].iter_mut().for_each(|item| *item /= scale);
        let pivot_row = aug[col].clone();

        // Eliminate this column from every other row.
        for (row_index, row) in aug.iter_mut().enumerate() {
            if row_index == col {
                continue;
            }
            let factor = row[col];
            if factor == 0.0 {
                continue;
            }
            for (target, &source) in row.iter_mut().zip(pivot_row.iter()) {
                *target -= factor * source;
            }
        }
    }

    // The right half of the augmented system now holds the inverse.
    Ok(aug.into_iter().map(|row| row[n..].to_vec()).collect())
}

/// Multiplies the leading `n x n` blocks of `left` and `right`.
///
/// Entry `(i, j)` of the result is the sum over `k` (in increasing order) of
/// `left[i][k] * right[k][j]`. Indexing panics if either operand has fewer
/// than `n` rows or columns in the positions that are read.
fn square_matmul(left: &[Vec<f64>], right: &[Vec<f64>], n: usize) -> Vec<Vec<f64>> {
    (0..n)
        .map(|i| {
            (0..n)
                .map(|j| {
                    // Explicit 0.0 seed keeps the accumulation identical to a
                    // zero-initialised running sum.
                    (0..n).fold(0.0_f64, |acc, k| acc + left[i][k] * right[k][j])
                })
                .collect::<Vec<f64>>()
        })
        .collect()
}

/// Builds the representation matrix `rho(g)` for one `(batch, atom)` entry.
///
/// * `SO2`: `g` must be rank 2; the entry is read as an angle `theta` and the
///   result is `[[cos, -sin], [sin, cos]]`.
/// * `SO3`: `g` must be rank 3 with a trailing axis of length 3; the three
///   entries are read as a rotation vector (axis scaled by angle) and turned
///   into a matrix with Rodrigues' formula `I + sin(a) K + (1 - cos(a)) K^2`.
///   The angle is floored at `1e-12` so a zero vector yields the identity
///   instead of dividing by zero.
/// * `R1` / `Trivial`: `g` must be rank 2; its values are not read and the
///   result is the `1 x 1` identity.
///
/// # Errors
/// Returns an error when `g` has the wrong rank or trailing dimension, when
/// the requested index is out of bounds, or when a coordinate that is read is
/// not finite (a NaN or infinite coordinate would otherwise propagate
/// silently as a NaN rotation matrix).
fn equivariant_rotation(
    group: EquivariantGroup,
    g: ArrayViewD<'_, f64>,
    batch: usize,
    atom: usize,
) -> Result<Vec<Vec<f64>>, String> {
    match group {
        EquivariantGroup::SO2 => {
            if g.ndim() != 2 {
                return Err("SO2 group coordinates must have shape (B, A)".to_string());
            }
            let theta = dynamic_value(&g, &[batch, atom], "SO2 group coordinates")?;
            if !theta.is_finite() {
                return Err(format!(
                    "SO2 group coordinates: entry [{batch},{atom}] is not finite"
                ));
            }
            let (s, c) = theta.sin_cos();
            Ok(vec![vec![c, -s], vec![s, c]])
        }
        EquivariantGroup::SO3 => {
            if g.ndim() != 3 || g.shape()[2] != 3 {
                return Err("SO3 group coordinates must have shape (B, A, 3)".to_string());
            }
            let mut omega = [0.0_f64; 3];
            for (axis, slot) in omega.iter_mut().enumerate() {
                let value = dynamic_value(&g, &[batch, atom, axis], "SO3 group coordinates")?;
                if !value.is_finite() {
                    return Err(format!(
                        "SO3 group coordinates: entry [{batch},{atom},{axis}] is not finite"
                    ));
                }
                *slot = value;
            }
            let [ox, oy, oz] = omega;
            // Floor the angle so the axis normalisation never divides by zero.
            let angle = (ox * ox + oy * oy + oz * oz).sqrt().max(1.0e-12);
            let ax = ox / angle;
            let ay = oy / angle;
            let az = oz / angle;
            // Skew-symmetric cross-product matrix of the unit axis.
            let k = vec![vec![0.0, -az, ay], vec![az, 0.0, -ax], vec![-ay, ax, 0.0]];
            let kk = square_matmul(&k, &k, 3);
            let s = angle.sin();
            let one_minus_c = 1.0 - angle.cos();
            let mut out = vec![vec![0.0_f64; 3]; 3];
            for i in 0..3 {
                for j in 0..3 {
                    let identity = if i == j { 1.0 } else { 0.0 };
                    out[i][j] = identity + s * k[i][j] + one_minus_c * kk[i][j];
                }
            }
            Ok(out)
        }
        EquivariantGroup::R1 | EquivariantGroup::Trivial => {
            if g.ndim() != 2 {
                return Err(format!(
                    "{} group coordinates must have shape (B, A)",
                    group.name()
                ));
            }
            Ok(vec![vec![1.0]])
        }
    }
}

/// Computes the weighted projection-residual penalty, plus an optional ARD
/// term on the log-bandwidths.
///
/// For each atom `a` and batch `b` the data term is
/// `0.5 * z[b, a] * ||P_perp W_a rho(g_ba)||_F^2`, where
/// `P_perp = I - W_a (W_a^T W_a + eps I)^-1 W_a^T` and `eps = 1e-6`.
/// The terms are summed over all atoms and batches, divided by `B * A`, and
/// multiplied by `weight`.
///
/// This is not the commutator `W rho(g) - rho(g) W`. It measures the
/// Frobenius-norm energy in the rotated basis columns that falls outside the
/// column space of `W_a`.
///
/// NOTE(review): algebraically
/// `P_perp W_a rho(g) = eps * W_a (W_a^T W_a + eps I)^-1 rho(g)`, so the
/// residual is nonzero only through the ridge term `eps` — confirm this is
/// the intended penalty.
///
/// When `log_bandwidth` is supplied and `ard_weight > 0`, the result also
/// includes `ard_weight * sum_a 0.5 * ln(1e-3 + log_bandwidth[a]^2)`.
///
/// # Arguments
/// * `group` - group label accepted by `EquivariantGroup::parse`.
/// * `w` - array indexed `[atom, ambient_dim, rep_dim]`; its last axis must
///   equal the group's representation dimension.
/// * `g` - group coordinates whose two leading axes match `z`.
/// * `z` - per-entry weights indexed `[batch, atom]`.
/// * `weight` - finite, strictly positive scale of the data term.
/// * `ard_weight` - finite, non-negative scale of the ARD term.
/// * `log_bandwidth` - optional vector with one entry per atom.
///
/// # Errors
/// Returns an error when a scalar weight is out of range, the group label is
/// unknown, any shape is inconsistent, `z` has no batches or no atoms (the
/// mean over `B * A` entries would be `0 / 0`), the regularised Gram matrix
/// cannot be inverted, a group coordinate cannot be read, or (when the ARD
/// term is active) a `log_bandwidth` entry is not finite.
pub fn equivariant_penalty_value(
    group: &str,
    w: ArrayView3<'_, f64>,
    g: ArrayViewD<'_, f64>,
    z: ArrayView2<'_, f64>,
    weight: f64,
    ard_weight: f64,
    log_bandwidth: Option<ArrayView1<'_, f64>>,
) -> Result<f64, String> {
    // Ridge added to the diagonal of W^T W before inversion.
    const GRAM_RIDGE: f64 = 1.0e-6;
    // Offset inside the logarithm of the ARD term; keeps ln() finite at zero.
    const ARD_OFFSET: f64 = 1.0e-3;

    if !(weight.is_finite() && weight > 0.0) {
        return Err(format!("weight must be finite and > 0, got {weight}"));
    }
    if !(ard_weight.is_finite() && ard_weight >= 0.0) {
        return Err(format!(
            "ard_weight must be finite and >= 0, got {ard_weight}"
        ));
    }
    let group = EquivariantGroup::parse(group)?;
    let expected_r = group.rep_dim();
    let (n_atoms, ambient_dim, rep_dim) = (w.shape()[0], w.shape()[1], w.shape()[2]);
    if rep_dim != expected_r {
        return Err(format!(
            "{} requires W.shape[2] == {expected_r}; got {rep_dim}",
            group.name()
        ));
    }
    if z.ncols() != n_atoms {
        return Err(format!("z has {} atoms but W has {n_atoms}", z.ncols()));
    }
    let batches = z.nrows();
    if g.ndim() < 2 || g.shape()[0] != batches || g.shape()[1] != n_atoms {
        return Err(format!(
            "g leading dimensions must match z shape ({batches}, {n_atoms})"
        ));
    }
    if let Some(log_bw) = log_bandwidth.as_ref() {
        if log_bw.len() != n_atoms {
            return Err(format!(
                "log_bandwidth length {} must equal atom count {n_atoms}",
                log_bw.len()
            ));
        }
    }
    // Without this guard the normalisation below evaluates 0.0 / 0.0 and the
    // function would return Ok(NaN) for empty input.
    if batches == 0 || n_atoms == 0 {
        return Err(format!(
            "z must contain at least one batch and one atom; got shape ({batches}, {n_atoms})"
        ));
    }

    let mut projection_total = 0.0_f64;
    for atom in 0..n_atoms {
        // Regularised Gram matrix W_a^T W_a + eps I; it depends only on the
        // atom, so it is inverted once per atom rather than once per batch.
        let mut wtw = vec![vec![0.0_f64; rep_dim]; rep_dim];
        for r1 in 0..rep_dim {
            for r2 in 0..rep_dim {
                let mut acc = 0.0_f64;
                for d in 0..ambient_dim {
                    acc += w[[atom, d, r1]] * w[[atom, d, r2]];
                }
                if r1 == r2 {
                    acc += GRAM_RIDGE;
                }
                wtw[r1][r2] = acc;
            }
        }
        let inv = invert_small_matrix(&wtw, "equivariant_penalty_value WtW")?;
        for batch in 0..batches {
            let rotation = equivariant_rotation(group, g.view(), batch, atom)?;
            // w_rot = W_a rho(g_ba): the rotated basis columns.
            let mut w_rot = vec![vec![0.0_f64; rep_dim]; ambient_dim];
            for d in 0..ambient_dim {
                for s_col in 0..rep_dim {
                    let mut acc = 0.0_f64;
                    for r_col in 0..rep_dim {
                        acc += w[[atom, d, r_col]] * rotation[r_col][s_col];
                    }
                    w_rot[d][s_col] = acc;
                }
            }
            // cross = W_a^T (W_a rho(g_ba)).
            let mut cross = vec![vec![0.0_f64; rep_dim]; rep_dim];
            for r_col in 0..rep_dim {
                for s_col in 0..rep_dim {
                    let mut acc = 0.0_f64;
                    for d in 0..ambient_dim {
                        acc += w[[atom, d, r_col]] * w_rot[d][s_col];
                    }
                    cross[r_col][s_col] = acc;
                }
            }
            // solve = (W_a^T W_a + eps I)^-1 cross, the projection coefficients.
            let solve = square_matmul(&inv, &cross, rep_dim);
            // Squared Frobenius norm of the part of w_rot that the projection
            // onto the columns of W_a does not reproduce.
            let mut sq = 0.0_f64;
            for d in 0..ambient_dim {
                for s_col in 0..rep_dim {
                    let mut projection = 0.0_f64;
                    for r_col in 0..rep_dim {
                        projection += w[[atom, d, r_col]] * solve[r_col][s_col];
                    }
                    let residual = w_rot[d][s_col] - projection;
                    sq += residual * residual;
                }
            }
            projection_total += 0.5 * z[[batch, atom]] * sq;
        }
    }

    // Mean over all (batch, atom) entries, scaled by the penalty weight.
    let mut value = weight * projection_total / ((batches * n_atoms) as f64);
    if let Some(log_bw) = log_bandwidth {
        if ard_weight > 0.0 {
            let mut bw_value = 0.0_f64;
            for bandwidth in log_bw.iter().copied() {
                if !bandwidth.is_finite() {
                    return Err("log_bandwidth entries must be finite".to_string());
                }
                bw_value += 0.5 * (ARD_OFFSET + bandwidth * bandwidth).ln();
            }
            value += ard_weight * bw_value;
        }
    }
    Ok(value)
}