use integral_math::am::{cart_components, cart_index, n_cart};
/// Describes one center-derivative accumulation: which Cartesian block is
/// differentiated, along which axis, and how the surrounding data is laid out.
///
/// Every block handed to `accumulate_center_derivative` is a row-major
/// `[outer][n_cart(..)][inner]` slab: `outer` is the slowest-varying extent,
/// the Cartesian component index sits in the middle, and `inner` is the
/// contiguous fastest-varying extent.
#[derive(Clone, Copy, Debug)]
pub struct AxisDeriv {
    /// Cartesian order of the output block (`n_cart(l)` components); the
    /// raised block has order `l + 1` and the lowered block order `l - 1`.
    pub l: usize,
    /// Axis of differentiation; must be 0, 1 or 2.
    pub cart_axis: usize,
    /// Extent of the leading (slowest-varying) dimension of every block.
    pub outer: usize,
    /// Extent of the trailing, contiguous dimension of every block.
    pub inner: usize,
}
/// Accumulates the derivative of a Cartesian block along one axis into `out`.
///
/// All blocks are row-major `[outer][n_cart(..)][inner]` slabs. For every
/// Cartesian component `a` of order `l`, every outer index `o` and every
/// inner index `s` this performs
///
/// ```text
/// out[o][a][s] += scale * (raised[o][a + e][s] - a_e * lowered[o][a - e][s])
/// ```
///
/// where `e` is the unit step along `d.cart_axis` and `a_e` is the power of
/// `a` along that axis; the lowered term is skipped when `a_e == 0`.
/// NOTE(review): this is the shape of the usual Cartesian Gaussian
/// center-derivative relation when `raised` already carries the `2α` factor —
/// confirm against callers.
///
/// # Arguments
/// * `d` - order, axis and `outer`/`inner` extents shared by all three blocks.
/// * `scale` - factor applied to each contribution before it is added.
/// * `raised` - order `l + 1` block, length `outer * n_cart(l + 1) * inner`.
/// * `lowered` - order `l - 1` block, length `outer * n_cart(l - 1) * inner`;
///   required when `l > 0`, ignored when `l == 0`.
/// * `out` - order `l` block, length `outer * n_cart(l) * inner`; values are
///   added to, not overwritten.
///
/// # Panics
/// Panics if `l > 0` and `lowered` is `None`; this is checked before `out` is
/// modified. Block lengths are verified only by `debug_assert!`; in release
/// builds an undersized block panics when sliced.
pub fn accumulate_center_derivative(
    d: AxisDeriv,
    scale: f64,
    raised: &[f64],
    lowered: Option<&[f64]>,
    out: &mut [f64],
) {
    let AxisDeriv {
        l,
        cart_axis,
        outer,
        inner,
    } = d;
    debug_assert!(cart_axis < 3, "cart_axis must be 0, 1, or 2");
    let nl = n_cart(l);
    let np1 = n_cart(l + 1);
    // An l == 0 block has no lower neighbour, so it contributes no rows.
    let nm1 = if l > 0 { n_cart(l - 1) } else { 0 };
    debug_assert_eq!(raised.len(), outer * np1 * inner, "raised block size");
    debug_assert_eq!(out.len(), outer * nl * inner, "output block size");
    // Resolve the optional block once, before any write to `out`, so a missing
    // block fails loudly up front instead of after a partial accumulation.
    let lowered: &[f64] = if l > 0 {
        lowered.expect("lowered block required for l > 0")
    } else {
        &[]
    };
    debug_assert_eq!(lowered.len(), outer * nm1 * inner, "lowered block size");

    for (a_idx, a) in cart_components(l).into_iter().enumerate() {
        let pw = a[cart_axis];
        let mut ar = a;
        ar[cart_axis] += 1;
        let r_idx = cart_index(ar);
        // Components with zero power along the axis have no lowered term.
        let lo_idx = (pw >= 1).then(|| {
            let mut al = a;
            al[cart_axis] -= 1;
            cart_index(al)
        });
        let coef = pw as f64;
        for o in 0..outer {
            let out_off = (o * nl + a_idx) * inner;
            let r_off = (o * np1 + r_idx) * inner;
            // Row subslices: one bounds check per row instead of per element.
            let out_row = &mut out[out_off..out_off + inner];
            let r_row = &raised[r_off..r_off + inner];
            match lo_idx {
                Some(li) => {
                    let lo_off = (o * nm1 + li) * inner;
                    let lo_row = &lowered[lo_off..lo_off + inner];
                    for ((dst, &r), &lo) in out_row.iter_mut().zip(r_row).zip(lo_row) {
                        *dst += scale * (r - coef * lo);
                    }
                }
                None => {
                    for (dst, &r) in out_row.iter_mut().zip(r_row) {
                        *dst += scale * r;
                    }
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand constructor for an `AxisDeriv` descriptor.
    fn desc(l: usize, cart_axis: usize, outer: usize, inner: usize) -> AxisDeriv {
        AxisDeriv {
            l,
            cart_axis,
            outer,
            inner,
        }
    }

    /// For l == 0 only the raised term exists; the axis selects the raised component.
    #[test]
    fn s_function_derivative_is_raise_only() {
        let raised = vec![10.0, 20.0, 30.0];
        let mut out = vec![0.0; 1];
        accumulate_center_derivative(desc(0, 0, 1, 1), 1.0, &raised, None, &mut out);
        assert_eq!(out[0], 10.0);
        let mut outy = vec![0.0; 1];
        accumulate_center_derivative(desc(0, 1, 1, 1), 1.0, &raised, None, &mut outy);
        assert_eq!(outy[0], 20.0);
    }

    /// A lowered block supplied for l == 0 must not contribute.
    #[test]
    fn s_function_ignores_lowered_block() {
        let raised = vec![10.0, 20.0, 30.0];
        let lowered = vec![999.0];
        let mut out = vec![0.0; 1];
        accumulate_center_derivative(desc(0, 0, 1, 1), 1.0, &raised, Some(&lowered), &mut out);
        assert_eq!(out[0], 10.0);
    }

    /// For l == 1 only the component with power along the axis gets a lowered term.
    #[test]
    fn p_function_derivative_has_raise_and_lower() {
        let d = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let s = vec![7.0];
        let mut out = vec![0.0; 3];
        accumulate_center_derivative(desc(1, 0, 1, 1), 1.0, &d, Some(&s), &mut out);
        assert_eq!(out[0], 1.0 - 7.0);
        assert_eq!(out[1], 2.0);
        assert_eq!(out[2], 3.0);
    }

    /// The lowered term is weighted by the component's power along the axis.
    #[test]
    fn lowered_term_is_weighted_by_axis_power() {
        // Only the x^3 raised value and the x^1 lowered value are non-zero, so
        // only the x^2 output component (power 2 along x) picks them up.
        let mut raised = vec![0.0; 10];
        raised[0] = 50.0;
        let lowered = vec![3.0, 0.0, 0.0];
        let mut out = vec![0.0; 6];
        accumulate_center_derivative(desc(2, 0, 1, 1), 1.0, &raised, Some(&lowered), &mut out);
        assert_eq!(out[0], 50.0 - 2.0 * 3.0);
        assert!(out[1..].iter().all(|&v| v == 0.0));
    }

    /// Lowered rows follow the same [outer][component][inner] layout as the other blocks.
    #[test]
    fn lowered_indexing_with_outer_and_inner() {
        let raised: Vec<f64> = (0..24).map(|i| i as f64).collect();
        let lowered = vec![100.0, 200.0, 300.0, 400.0];
        let mut out = vec![0.0; 12];
        accumulate_center_derivative(desc(1, 0, 2, 2), 1.0, &raised, Some(&lowered), &mut out);
        assert_eq!(
            out,
            vec![
                -100.0, -199.0, 2.0, 3.0, 4.0, 5.0, -288.0, -387.0, 14.0, 15.0, 16.0, 17.0,
            ]
        );
    }

    /// A missing lowered block for l > 0 is a contract violation.
    #[test]
    #[should_panic]
    fn missing_lowered_block_panics() {
        let raised = vec![0.0; 6];
        let mut out = vec![0.0; 3];
        accumulate_center_derivative(desc(1, 0, 1, 1), 1.0, &raised, None, &mut out);
    }

    /// Results are scaled and added to, not overwritten.
    #[test]
    fn scale_multiplies_and_accumulates() {
        let raised = vec![10.0, 0.0, 0.0];
        let mut out = vec![0.0; 1];
        accumulate_center_derivative(desc(0, 0, 1, 1), 2.0, &raised, None, &mut out);
        accumulate_center_derivative(desc(0, 0, 1, 1), 3.0, &raised, None, &mut out);
        assert_eq!(out[0], (2.0 + 3.0) * 10.0);
    }

    /// Each outer slab reads and writes its own rows.
    #[test]
    fn outer_indexing_places_correctly() {
        let raised = vec![1.0, 0.0, 0.0, 2.0, 0.0, 0.0];
        let mut out = vec![0.0; 2];
        accumulate_center_derivative(desc(0, 0, 2, 1), 1.0, &raised, None, &mut out);
        assert_eq!(out, vec![1.0, 2.0]);
    }

    /// Inner values are contiguous within a component row.
    #[test]
    fn inner_indexing_places_correctly() {
        let raised = vec![11.0, 12.0, 21.0, 22.0, 31.0, 32.0];
        let mut out = vec![0.0; 2];
        accumulate_center_derivative(desc(0, 0, 1, 2), 1.0, &raised, None, &mut out);
        assert_eq!(out, vec![11.0, 12.0]);
    }
}