use crate::families::custom_family::{
AdditiveBlockJacobian, BlockWorkingSet, CustomFamily, ExactNewtonJointGradientEvaluation,
ExactNewtonJointHessianWorkspace, FamilyEvaluation, ParameterBlockSpec, ParameterBlockState,
PenaltyMatrix,
};
use crate::families::gamlss::{FamilyMetadata, ParameterLink};
use crate::families::vector_response::{
MultinomialLogitLikelihood, VectorLikelihood, validate_multinomial_simplex,
};
use crate::matrix::{DenseDesignMatrix, DesignMatrix, SymmetricMatrix};
use crate::pirls::dense_block_xtwx;
use ndarray::{Array1, Array2, Array3, ArrayView2};
use std::sync::Arc;
/// State of a multinomial-logit family: response, row weights, a design matrix
/// shared by every class, and the penalties attached to that design.
///
/// Values are built through `MultinomialFamily::new`, which checks shapes and
/// finiteness before constructing the private likelihood.
#[derive(Clone, Debug)]
pub struct MultinomialFamily {
/// Response matrix; `new` requires `total_classes` columns and validates it
/// with `validate_multinomial_simplex`.
pub y_one_hot: Array2<f64>,
/// Per-row weights; `new` requires one finite, non-negative value per response row.
pub weights: Array1<f64>,
/// Total number of classes `K` (`new` rejects `K < 2`).
pub total_classes: usize,
/// Design matrix reused by every class block; `new` requires one row per response row.
pub design: Arc<Array2<f64>>,
/// Penalty matrices; `new` requires each to have shape `(P, P)` with `P = design.ncols()`.
pub penalties: Arc<Vec<PenaltyMatrix>>,
/// One entry per penalty (lengths are checked to match in `new`).
pub penalty_nullspace_dims: Arc<Vec<usize>>,
// Likelihood built in `new` from `total_classes` and a copy of `weights`.
likelihood: MultinomialLogitLikelihood,
}
impl MultinomialFamily {
/// Number of classes that receive their own coefficient block, i.e. `total_classes - 1`.
pub const fn active_classes(&self) -> usize {
    let total = self.total_classes;
    total - 1
}
/// Returns the labels `class_0`, `class_1`, …, one per active class.
pub fn parameter_names(&self) -> Vec<String> {
    let count = self.active_classes();
    let mut names = Vec::with_capacity(count);
    for a in 0..count {
        names.push(format!("class_{a}"));
    }
    names
}
/// Returns one `ParameterLink::Identity` per active class.
pub fn parameter_links(&self) -> Vec<ParameterLink> {
    (0..self.active_classes())
        .map(|_| ParameterLink::Identity)
        .collect()
}
/// Static family metadata.
///
/// The name/link slices are empty here; the per-instance lists come from
/// `parameter_names` and `parameter_links`, which depend on the class count.
pub fn metadata() -> FamilyMetadata {
    let name = "multinomial_logit";
    FamilyMetadata {
        name,
        parameternames: &[],
        parameter_links: &[],
    }
}
/// Validates the inputs and builds a `MultinomialFamily`.
///
/// # Errors
///
/// Returns a descriptive `String` when, checked in this order:
/// * `total_classes < 2`;
/// * `y_one_hot` does not have `total_classes` columns;
/// * `weights` does not have one entry per response row, or an entry is
///   non-finite or negative;
/// * `design` does not have one row per response row;
/// * `penalty_nullspace_dims` and `penalties` differ in length;
/// * a penalty is not `(P, P)` with `P = design.ncols()`, or has a non-finite entry;
/// * `validate_multinomial_simplex` rejects `y_one_hot`;
/// * `design` has a non-finite entry;
/// * the likelihood constructor or its row-weight setter fails.
pub fn new(
    y_one_hot: Array2<f64>,
    weights: Array1<f64>,
    total_classes: usize,
    design: Arc<Array2<f64>>,
    penalties: Arc<Vec<PenaltyMatrix>>,
    penalty_nullspace_dims: Arc<Vec<usize>>,
) -> Result<Self, String> {
    if total_classes < 2 {
        return Err(format!(
            "MultinomialFamily requires K ≥ 2 classes (got {total_classes})"
        ));
    }

    // Shape agreement between the response and the weights.
    let (n, k) = y_one_hot.dim();
    if k != total_classes {
        return Err(format!(
            "MultinomialFamily: y_one_hot has {k} columns but total_classes = {total_classes}"
        ));
    }
    if weights.len() != n {
        return Err(format!(
            "MultinomialFamily: weights length {} != N = {n}",
            weights.len()
        ));
    }

    // First weight that is NaN, infinite, or negative (if any).
    let offending_weight = weights
        .iter()
        .enumerate()
        .find(|&(_, &w)| !w.is_finite() || w < 0.0);
    if let Some((i, &v)) = offending_weight {
        return Err(format!(
            "MultinomialFamily: weights[{i}] must be finite and non-negative (got {v})"
        ));
    }

    if design.nrows() != n {
        return Err(format!(
            "MultinomialFamily: design has {} rows, expected {n}",
            design.nrows()
        ));
    }
    let p = design.ncols();

    if penalty_nullspace_dims.len() != penalties.len() {
        return Err(format!(
            "MultinomialFamily: penalty_nullspace_dims length {} != penalties length {}",
            penalty_nullspace_dims.len(),
            penalties.len()
        ));
    }

    // Every penalty must be P×P and free of NaN/inf entries.
    for (t, penalty) in penalties.iter().enumerate() {
        if penalty.shape() != (p, p) {
            return Err(format!(
                "MultinomialFamily: penalties[{t}] shape {:?} != (P, P) = ({p}, {p})",
                penalty.shape()
            ));
        }
        let dense = penalty.to_dense();
        let offending_entry = dense.indexed_iter().find(|(_, entry)| !entry.is_finite());
        if let Some(((i, j), &v)) = offending_entry {
            return Err(format!(
                "MultinomialFamily: penalties[{t}][{i},{j}] must be finite (got {v})"
            ));
        }
    }

    validate_multinomial_simplex(y_one_hot.view(), "MultinomialFamily")
        .map_err(|e| e.to_string())?;

    let offending_design = design.indexed_iter().find(|(_, entry)| !entry.is_finite());
    if let Some(((i, j), &v)) = offending_design {
        return Err(format!(
            "MultinomialFamily: design[{i},{j}] must be finite (got {v})"
        ));
    }

    // The likelihood keeps its own copy of the row weights.
    let unweighted = MultinomialLogitLikelihood::with_classes(total_classes)
        .map_err(|e| format!("MultinomialFamily: {e}"))?;
    let likelihood = unweighted
        .with_row_weights(weights.clone())
        .map_err(|e| format!("MultinomialFamily: {e}"))?;

    Ok(Self {
        y_one_hot,
        weights,
        total_classes,
        design,
        penalties,
        penalty_nullspace_dims,
        likelihood,
    })
}
/// Builds one `ParameterBlockSpec` per active class, in class order.
///
/// Every block shares the family's design, receives a full copy of the
/// penalty list and null-space dimensions, starts with zero offsets and
/// zero log-λ values, and carries an `AdditiveBlockJacobian` naming its own
/// output index. The gauge priority is `100 + (m - a)`, with both the
/// conversion to `u8` and the addition saturating at `u8::MAX`.
pub fn build_block_specs(&self) -> Vec<ParameterBlockSpec> {
    let m = self.active_classes();
    let n_rows = self.design.nrows();
    let n_terms = self.penalties.len();
    let mut specs = Vec::with_capacity(m);
    for a in 0..m {
        let rank_from_end = u8::try_from(m - a).unwrap_or(u8::MAX);
        let jacobian = AdditiveBlockJacobian {
            design: (*self.design).clone(),
            own_output: a,
            n_family_outputs: m,
        };
        specs.push(ParameterBlockSpec {
            name: format!("class_{a}"),
            design: DesignMatrix::Dense(DenseDesignMatrix::from(Arc::clone(&self.design))),
            offset: Array1::<f64>::zeros(n_rows),
            penalties: (*self.penalties).clone(),
            nullspace_dims: (*self.penalty_nullspace_dims).clone(),
            initial_log_lambdas: Array1::<f64>::zeros(n_terms),
            initial_beta: None,
            gauge_priority: 100u8.saturating_add(rank_from_end),
            jacobian_callback: Some(Arc::new(jacobian)),
            stacked_design: None,
            stacked_offset: None,
        });
    }
    specs
}
/// Length of the flattened coefficient vector: active classes × design columns.
pub fn beta_flat_dim(&self) -> usize {
    let (classes, columns) = (self.active_classes(), self.design.ncols());
    classes * columns
}
/// Stacks the per-block linear predictors into an `N × (K-1)` matrix,
/// block `a` becoming column `a`.
///
/// # Errors
///
/// Fails when the number of blocks is not `K-1` or when any block's `eta`
/// does not have one entry per weight.
fn collect_eta_matrix(
    &self,
    block_states: &[ParameterBlockState],
) -> Result<Array2<f64>, String> {
    let m = self.active_classes();
    let n = self.weights.len();
    if block_states.len() != m {
        return Err(format!(
            "MultinomialFamily expects {m} blocks (K-1), got {}",
            block_states.len()
        ));
    }
    let mut eta = Array2::<f64>::zeros((n, m));
    for (a, state) in block_states.iter().enumerate() {
        if state.eta.len() != n {
            return Err(format!(
                "MultinomialFamily block {a} eta length {} != N = {n}",
                state.eta.len()
            ));
        }
        // Lengths were just checked, so this is a plain element-wise copy.
        eta.column_mut(a).assign(&state.eta);
    }
    Ok(eta)
}
/// Evaluates the likelihood at `eta` and returns, in order, the value from
/// `log_lik`, the array from `hess_block`, and the array from `grad_eta`.
fn evaluate_row_kernels(&self, eta: ArrayView2<'_, f64>) -> (f64, Array3<f64>, Array2<f64>) {
    let response = self.y_one_hot.view();
    let log_lik = self.likelihood.log_lik(eta, response);
    let fisher = self.likelihood.hess_block(eta, response);
    let grad_eta_logl = self.likelihood.grad_eta(eta, response);
    (log_lik, fisher, grad_eta_logl)
}
/// Builds one `BlockWorkingSet::ExactNewton` per active class from the row kernels.
///
/// For class `a` the gradient is `Xᵀ · (-grad_eta_logl[:, a])` and the Hessian is
/// `Xᵀ · diag(fisher[:, a, a]) · X`. Only the diagonal `(a, a)` entries of `fisher`
/// are read, so cross-class coupling does not appear in these working sets.
/// This implementation has no failing path and always returns `Ok`.
fn assemble_block_diagonal_working_sets(
&self,
fisher: &Array3<f64>,
grad_eta_logl: &Array2<f64>,
) -> Result<Vec<BlockWorkingSet>, String> {
let n = self.weights.len();
let p = self.design.ncols();
let m = self.active_classes();
let design_view = self.design.view();
let mut sets = Vec::with_capacity(m);
for a in 0..m {
// Gradient: column-by-column dot product of the design with the negated
// η-gradient of class `a`.
let mut grad = Array1::<f64>::zeros(p);
for i in 0..p {
let mut acc = 0.0_f64;
for row in 0..n {
acc += design_view[[row, i]] * (-grad_eta_logl[[row, a]]);
}
grad[i] = acc;
}
// Hessian: accumulate the rank-one contribution of each row, skipping rows
// and columns whose factor is exactly zero (they add nothing).
let mut hess = Array2::<f64>::zeros((p, p));
for row in 0..n {
let w_aa = fisher[[row, a, a]];
if w_aa == 0.0 {
continue;
}
for i in 0..p {
let xi = design_view[[row, i]];
if xi == 0.0 {
continue;
}
let scaled = w_aa * xi;
for j in 0..p {
hess[[i, j]] += scaled * design_view[[row, j]];
}
}
}
// (i, j) and (j, i) are accumulated with differently ordered products, so
// they can differ by rounding; average them to make the result exactly symmetric.
for i in 0..p {
for j in (i + 1)..p {
let avg = 0.5 * (hess[[i, j]] + hess[[j, i]]);
hess[[i, j]] = avg;
hess[[j, i]] = avg;
}
}
sets.push(BlockWorkingSet::ExactNewton {
gradient: grad,
hessian: SymmetricMatrix::Dense(hess),
});
}
Ok(sets)
}
/// Assembles the joint Hessian by passing the design and the per-row
/// `fisher` blocks to `dense_block_xtwx`, prefixing any error with context.
fn assemble_joint_hessian(&self, fisher: &Array3<f64>) -> Result<Array2<f64>, String> {
    let assembled = dense_block_xtwx(self.design.view(), fisher.view(), None);
    match assembled {
        Ok(hessian) => Ok(hessian),
        Err(e) => Err(format!("MultinomialFamily joint Hessian assembly: {e}")),
    }
}
/// Computes the flattened `Xᵀ · grad_eta_logl`, laid out class-major:
/// entry `a * P + i` is the dot product of design column `i` with
/// η-gradient column `a`, summed over rows in increasing row order.
fn assemble_joint_gradient(&self, grad_eta_logl: &Array2<f64>) -> Array1<f64> {
    let n = self.weights.len();
    let p = self.design.ncols();
    let m = self.active_classes();
    let x = self.design.view();
    // When `p == 0` the output is empty and the closure (with its division
    // by `p`) is never invoked.
    Array1::from_shape_fn(m * p, |flat| {
        let (a, i) = (flat / p, flat % p);
        (0..n).fold(0.0_f64, |acc, row| {
            acc + x[[row, i]] * grad_eta_logl[[row, a]]
        })
    })
}
/// Maps a flattened coefficient direction to the induced change in the
/// linear predictors: column `a` of the result is `X · d_beta[a*P .. (a+1)*P]`.
///
/// # Errors
///
/// Fails when `d_beta_flat` does not have `(K-1)·P` entries.
fn d_eta_from_d_beta(&self, d_beta_flat: &Array1<f64>) -> Result<Array2<f64>, String> {
    let p = self.design.ncols();
    let m = self.active_classes();
    let n = self.design.nrows();
    if d_beta_flat.len() != m * p {
        return Err(format!(
            "MultinomialFamily direction length {} != (K-1)·P = {}",
            d_beta_flat.len(),
            m * p
        ));
    }
    let x = self.design.view();
    let mut d_eta = Array2::<f64>::zeros((n, m));
    // Each output entry is an independent dot product over the design columns.
    for ((row, class), slot) in d_eta.indexed_iter_mut() {
        let offset = class * p;
        let mut dot = 0.0_f64;
        for col in 0..p {
            dot += x[[row, col]] * d_beta_flat[offset + col];
        }
        *slot = dot;
    }
    Ok(d_eta)
}
/// Returns the likelihood's `probabilities(eta)` unchanged.
///
/// Callers in this file index only columns `0..active_classes()` of the result.
fn row_probabilities(&self, eta: ArrayView2<'_, f64>) -> Array2<f64> {
self.likelihood.probabilities(eta)
}
/// Matrix-free product of the joint Hessian with `v`, written into `out`.
///
/// Per row with weight `w` and probabilities `p_a` (first `K-1` columns of
/// `probs_full`), with `xv_a = x · v_a` and `s = Σ_b p_b · xv_b`, the row adds
/// `w · p_a · (xv_a − s) · x` to block `a` of `out`. `out` is zeroed first.
///
/// # Errors
///
/// Fails when `v` or `out` does not have `(K-1)·P` entries.
fn hessian_matvec_into_with_probs(
    &self,
    probs_full: ArrayView2<'_, f64>,
    v: &Array1<f64>,
    out: &mut Array1<f64>,
) -> Result<(), String> {
    let p = self.design.ncols();
    let m = self.active_classes();
    let total = m * p;
    if v.len() != total {
        return Err(format!(
            "MultinomialHessianWorkspace::hessian_matvec: v len {} != (K-1)·P = {total}",
            v.len()
        ));
    }
    if out.len() != total {
        return Err(format!(
            "MultinomialHessianWorkspace::hessian_matvec: out len {} != (K-1)·P = {total}",
            out.len()
        ));
    }
    out.fill(0.0);
    let x = self.design.view();
    // Scratch buffer holding x·v_b for each class of the current row.
    let mut class_dots = vec![0.0_f64; m];
    for (row, &w) in self.weights.iter().enumerate() {
        // Zero-weight rows contribute nothing.
        if w == 0.0 {
            continue;
        }
        let mut s = 0.0_f64;
        for (b, slot) in class_dots.iter_mut().enumerate() {
            let offset = b * p;
            let mut dot = 0.0_f64;
            for j in 0..p {
                dot += x[[row, j]] * v[offset + j];
            }
            *slot = dot;
            s += probs_full[[row, b]] * dot;
        }
        for (a, &xv_a) in class_dots.iter().enumerate() {
            let r = w * probs_full[[row, a]] * (xv_a - s);
            if r == 0.0 {
                continue;
            }
            let base = a * p;
            for i in 0..p {
                out[base + i] += x[[row, i]] * r;
            }
        }
    }
    Ok(())
}
/// Diagonal of the joint Hessian computed directly from the probabilities.
///
/// Entry `a * P + i` accumulates `w · p_a · (1 − p_a) · x_i²` over rows, where
/// `p_a` is read from column `a` of `probs_full`. A unit test in this file
/// compares this result bit-for-bit with the dense diagonal, so the order of
/// the floating-point operations below should not be changed casually.
fn hessian_diagonal_with_probs(&self, probs_full: ArrayView2<'_, f64>) -> Array1<f64> {
let p = self.design.ncols();
let m = self.active_classes();
let n = self.weights.len();
let mut out = Array1::<f64>::zeros(m * p);
let design = self.design.view();
for row in 0..n {
let w = self.weights[row];
// Zero-weight rows contribute nothing.
if w == 0.0 {
continue;
}
for a in 0..m {
let pa = probs_full[[row, a]];
// Diagonal (a, a) entry of the per-row weight block.
let waa = w * pa * (1.0 - pa);
if waa == 0.0 {
continue;
}
let base = a * p;
for i in 0..p {
let xi = design[[row, i]];
out[base + i] += waa * xi * xi;
}
}
}
out
}
/// First directional derivative of the per-row weight blocks along `d_beta_flat`.
///
/// With `dη = X · dβ`, `s = Σ_a p_a · dη_a` and `dp_a = p_a · (dη_a − s)`, each
/// row's block is
/// * diagonal: `w · (dp_a − 2 · dp_a · p_a)`;
/// * off-diagonal: `−w · (dp_a · p_b + p_a · dp_b)`, mirrored for symmetry.
/// Rows with zero weight stay zero.
///
/// # Errors
///
/// Propagates the length check of `d_eta_from_d_beta`.
fn directional_fisher_jet(
    &self,
    eta: ArrayView2<'_, f64>,
    d_beta_flat: &Array1<f64>,
) -> Result<Array3<f64>, String> {
    let m = self.active_classes();
    let probs = self.row_probabilities(eta);
    let d_eta = self.d_eta_from_d_beta(d_beta_flat)?;
    let mut jet = Array3::<f64>::zeros((self.weights.len(), m, m));
    // Scratch buffer for the probability derivatives of the current row.
    let mut dp = vec![0.0_f64; m];
    for (row, &w) in self.weights.iter().enumerate() {
        if w == 0.0 {
            continue;
        }
        let mut s = 0.0_f64;
        for a in 0..m {
            s += probs[[row, a]] * d_eta[[row, a]];
        }
        for (a, slot) in dp.iter_mut().enumerate() {
            *slot = probs[[row, a]] * (d_eta[[row, a]] - s);
        }
        for a in 0..m {
            let pa = probs[[row, a]];
            jet[[row, a, a]] = w * (dp[a] - 2.0 * dp[a] * pa);
            for b in (a + 1)..m {
                let pb = probs[[row, b]];
                let cross = w * (-(dp[a] * pb + pa * dp[b]));
                jet[[row, a, b]] = cross;
                jet[[row, b, a]] = cross;
            }
        }
    }
    Ok(jet)
}
/// Second directional derivative of the per-row weight blocks along the pair
/// of coefficient directions `d_beta_u` and `d_beta_v`.
///
/// Both directions are first mapped to predictor space with
/// `d_eta_from_d_beta` (whose length check is propagated as an error).
/// Rows with zero weight are left at zero.
fn second_directional_fisher_jet(
&self,
eta: ArrayView2<'_, f64>,
d_beta_u: &Array1<f64>,
d_beta_v: &Array1<f64>,
) -> Result<Array3<f64>, String> {
let n = self.weights.len();
let m = self.active_classes();
let probs_full = self.row_probabilities(eta);
let d_eta_u = self.d_eta_from_d_beta(d_beta_u)?;
let d_eta_v = self.d_eta_from_d_beta(d_beta_v)?;
let mut out = Array3::<f64>::zeros((n, m, m));
// Per-row scratch buffers: first derivatives of the probabilities along u
// and v, and the mixed second derivative.
let mut dp_u = vec![0.0_f64; m];
let mut dp_v = vec![0.0_f64; m];
let mut ddp = vec![0.0_f64; m];
for row in 0..n {
let w = self.weights[row];
if w == 0.0 {
continue;
}
// s_u = Σ_a p_a · dη_u[a], and likewise for v.
let mut s_u = 0.0_f64;
let mut s_v = 0.0_f64;
for a in 0..m {
s_u += probs_full[[row, a]] * d_eta_u[[row, a]];
s_v += probs_full[[row, a]] * d_eta_v[[row, a]];
}
// dp_·[a] = p_a · (dη_·[a] − s_·).
for a in 0..m {
let pa = probs_full[[row, a]];
dp_u[a] = pa * (d_eta_u[[row, a]] - s_u);
dp_v[a] = pa * (d_eta_v[[row, a]] - s_v);
}
// Derivative of s_u along v: Σ_c dp_v[c] · dη_u[c].
let mut ds_u_dv = 0.0_f64;
for c in 0..m {
ds_u_dv += dp_v[c] * d_eta_u[[row, c]];
}
// Product rule applied to dp_u[a] = p_a · (dη_u[a] − s_u) along v.
for a in 0..m {
let pa = probs_full[[row, a]];
ddp[a] = dp_v[a] * (d_eta_u[[row, a]] - s_u) + pa * (-ds_u_dv);
}
for a in 0..m {
let pa = probs_full[[row, a]];
// Diagonal entry: ddp_a − 2·ddp_a·p_a − 2·dp_u[a]·dp_v[a], scaled by w.
out[[row, a, a]] = w * (ddp[a] - 2.0 * ddp[a] * pa - 2.0 * dp_u[a] * dp_v[a]);
for b in (a + 1)..m {
let pb = probs_full[[row, b]];
// Off-diagonal entry, written to both (a, b) and (b, a).
let off =
w * (-(ddp[a] * pb + dp_u[a] * dp_v[b] + dp_v[a] * dp_u[b] + pa * ddp[b]));
out[[row, a, b]] = off;
out[[row, b, a]] = off;
}
}
}
Ok(out)
}
}
impl CustomFamily for MultinomialFamily {
    /// This family reports that its joint Hessian depends on the coefficients.
    fn exact_newton_joint_hessian_beta_dependent(&self) -> bool {
        true
    }

    /// This family provides an explicit joint Hessian.
    fn has_explicit_joint_hessian(&self) -> bool {
        true
    }

    /// This family requests the joint outer hyper-parameter path.
    fn requires_joint_outer_hyper_path(&self) -> bool {
        true
    }

    /// Delegates the cost estimate to `joint_coupled_coefficient_hessian_cost`,
    /// passing the number of rows and the block specs.
    fn coefficient_hessian_cost(&self, specs: &[ParameterBlockSpec]) -> u64 {
        crate::custom_family::joint_coupled_coefficient_hessian_cost(
            self.weights.len() as u64,
            specs,
        )
    }

    /// Evaluates the log-likelihood and the per-class (block-diagonal)
    /// exact-Newton working sets at the given block states.
    ///
    /// # Errors
    ///
    /// Fails when the block states do not match the family's shape
    /// (see `collect_eta_matrix`).
    fn evaluate(&self, block_states: &[ParameterBlockState]) -> Result<FamilyEvaluation, String> {
        let eta = self.collect_eta_matrix(block_states)?;
        let (log_lik, fisher, grad_eta_logl) = self.evaluate_row_kernels(eta.view());
        let working_sets = self.assemble_block_diagonal_working_sets(&fisher, &grad_eta_logl)?;
        Ok(FamilyEvaluation {
            log_likelihood: log_lik,
            blockworking_sets: working_sets,
        })
    }

    /// Evaluates only the log-likelihood, skipping gradient and Hessian kernels.
    fn log_likelihood_only(&self, block_states: &[ParameterBlockState]) -> Result<f64, String> {
        let eta = self.collect_eta_matrix(block_states)?;
        Ok(self.likelihood.log_lik(eta.view(), self.y_one_hot.view()))
    }

    /// Dense joint Hessian over all class blocks.
    ///
    /// Only the per-row Hessian kernel is evaluated here; the log-likelihood
    /// and η-gradient are not needed for the Hessian and are not computed.
    fn exact_newton_joint_hessian(
        &self,
        block_states: &[ParameterBlockState],
    ) -> Result<Option<Array2<f64>>, String> {
        let eta = self.collect_eta_matrix(block_states)?;
        let fisher = self
            .likelihood
            .hess_block(eta.view(), self.y_one_hot.view());
        let hessian = self.assemble_joint_hessian(&fisher)?;
        Ok(Some(hessian))
    }

    /// Log-likelihood together with the flattened joint gradient produced by
    /// `assemble_joint_gradient`. `block_specs` is not otherwise used.
    fn exact_newton_joint_gradient_evaluation(
        &self,
        block_states: &[ParameterBlockState],
        block_specs: &[ParameterBlockSpec],
    ) -> Result<Option<ExactNewtonJointGradientEvaluation>, String> {
        // NOTE(review): this bound looks like it can never fail for a real
        // slice; presumably it exists only to reference `block_specs` — confirm
        // before removing.
        assert!(block_specs.len() <= isize::MAX as usize);
        let eta = self.collect_eta_matrix(block_states)?;
        let log_lik = self.likelihood.log_lik(eta.view(), self.y_one_hot.view());
        let grad_eta_logl = self.likelihood.grad_eta(eta.view(), self.y_one_hot.view());
        let gradient = self.assemble_joint_gradient(&grad_eta_logl);
        Ok(Some(ExactNewtonJointGradientEvaluation {
            log_likelihood: log_lik,
            gradient,
        }))
    }

    /// Builds a workspace that caches the row probabilities at `block_states`
    /// and holds a clone of the family plus a copy of the states.
    /// `block_specs` is not otherwise used.
    fn exact_newton_joint_hessian_workspace(
        &self,
        block_states: &[ParameterBlockState],
        block_specs: &[ParameterBlockSpec],
    ) -> Result<Option<Arc<dyn ExactNewtonJointHessianWorkspace>>, String> {
        // NOTE(review): same apparently vacuous bound as in the gradient
        // evaluation above — confirm before removing.
        assert!(block_specs.len() <= isize::MAX as usize);
        let eta = self.collect_eta_matrix(block_states)?;
        let probs = self.row_probabilities(eta.view());
        Ok(Some(Arc::new(MultinomialHessianWorkspace {
            family: self.clone(),
            block_states: block_states.to_vec(),
            probs,
        })))
    }

    /// Directional derivative of the joint Hessian along `d_beta_flat`,
    /// assembled from `directional_fisher_jet` with `dense_block_xtwx`.
    fn exact_newton_joint_hessian_directional_derivative(
        &self,
        block_states: &[ParameterBlockState],
        d_beta_flat: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String> {
        let eta = self.collect_eta_matrix(block_states)?;
        let dh_fisher = self.directional_fisher_jet(eta.view(), d_beta_flat)?;
        let dh = dense_block_xtwx(self.design.view(), dh_fisher.view(), None)
            .map_err(|e| format!("MultinomialFamily directional H assembly: {e}"))?;
        Ok(Some(dh))
    }

    /// Second directional derivative of the joint Hessian along the pair
    /// (`d_beta_u_flat`, `d_beta_v_flat`), assembled from
    /// `second_directional_fisher_jet` with `dense_block_xtwx`.
    fn exact_newton_joint_hessiansecond_directional_derivative(
        &self,
        block_states: &[ParameterBlockState],
        d_beta_u_flat: &Array1<f64>,
        d_beta_v_flat: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String> {
        let eta = self.collect_eta_matrix(block_states)?;
        let d2h_fisher =
            self.second_directional_fisher_jet(eta.view(), d_beta_u_flat, d_beta_v_flat)?;
        let d2h = dense_block_xtwx(self.design.view(), d2h_fisher.view(), None)
            .map_err(|e| format!("MultinomialFamily second directional H assembly: {e}"))?;
        Ok(Some(d2h))
    }
}
/// Workspace backing the matrix-free joint-Hessian operations of `MultinomialFamily`.
///
/// It is created by `exact_newton_joint_hessian_workspace`, which stores a clone
/// of the family, a copy of the block states, and the row probabilities
/// evaluated at those states.
struct MultinomialHessianWorkspace {
// Clone of the family the workspace was built from.
family: MultinomialFamily,
// Block states the workspace was built at; reused by the dense and derivative paths.
block_states: Vec<ParameterBlockState>,
// Row probabilities at `block_states`; reused by the matvec and diagonal paths.
probs: Array2<f64>,
}
impl ExactNewtonJointHessianWorkspace for MultinomialHessianWorkspace {
    /// Dense joint Hessian, recomputed from the stored block states.
    fn hessian_dense(&self) -> Result<Option<Array2<f64>>, String> {
        let states = self.block_states.as_slice();
        self.family.exact_newton_joint_hessian(states)
    }

    /// Matrix-free products are always supported by this workspace.
    fn hessian_matvec_available(&self) -> bool {
        true
    }

    /// Hessian–vector product returned as a freshly allocated vector.
    fn hessian_matvec(&self, v: &Array1<f64>) -> Result<Option<Array1<f64>>, String> {
        let mut product = Array1::<f64>::zeros(self.family.beta_flat_dim());
        match self
            .family
            .hessian_matvec_into_with_probs(self.probs.view(), v, &mut product)
        {
            Ok(()) => Ok(Some(product)),
            Err(message) => Err(message),
        }
    }

    /// Hessian–vector product written into `out`; reports `true` on success.
    fn hessian_matvec_into(&self, v: &Array1<f64>, out: &mut Array1<f64>) -> Result<bool, String> {
        self.family
            .hessian_matvec_into_with_probs(self.probs.view(), v, out)
            .map(|()| true)
    }

    /// Hessian diagonal computed from the cached probabilities.
    fn hessian_diagonal(&self) -> Result<Option<Array1<f64>>, String> {
        let diagonal = self.family.hessian_diagonal_with_probs(self.probs.view());
        Ok(Some(diagonal))
    }

    /// Directional derivative of the Hessian at the stored block states.
    fn directional_derivative(
        &self,
        d_beta_flat: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String> {
        let states = self.block_states.as_slice();
        self.family
            .exact_newton_joint_hessian_directional_derivative(states, d_beta_flat)
    }

    /// Second directional derivative of the Hessian at the stored block states.
    fn second_directional_derivative(
        &self,
        d_beta_u: &Array1<f64>,
        d_beta_v: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String> {
        let states = self.block_states.as_slice();
        self.family
            .exact_newton_joint_hessiansecond_directional_derivative(states, d_beta_u, d_beta_v)
    }
}
#[cfg(test)]
mod tests {
use super::*;
use ndarray::array;
/// Small deterministic family: unit weights, observation `i` in class `i % k`,
/// a sine-patterned design, and a single identity penalty.
fn toy_family(n_obs: usize, p: usize, k: usize) -> MultinomialFamily {
    let mut y = Array2::<f64>::zeros((n_obs, k));
    for row in 0..n_obs {
        y[[row, row % k]] = 1.0;
    }
    let weights = Array1::<f64>::ones(n_obs);
    let design_values =
        Array2::<f64>::from_shape_fn((n_obs, p), |(i, j)| ((i + j + 1) as f64).sin());
    let identity =
        Array2::<f64>::from_shape_fn((p, p), |(i, j)| if i == j { 1.0 } else { 0.0 });
    let penalties = Arc::new(vec![crate::custom_family::PenaltyMatrix::Dense(identity)]);
    let nullspace_dims = Arc::new(vec![0usize]);
    MultinomialFamily::new(
        y,
        weights,
        k,
        Arc::new(design_values),
        penalties,
        nullspace_dims,
    )
    .expect("toy MultinomialFamily must construct")
}
/// K = 4 classes must yield three blocks named `class_0..class_2`, in order.
#[test]
fn block_specs_have_one_per_active_class_in_order() {
    let specs = toy_family(8, 3, 4).build_block_specs();
    assert_eq!(specs.len(), 3, "expected K-1 = 3 active blocks for K=4");
    specs.iter().enumerate().for_each(|(a, spec)| {
        assert_eq!(spec.name, format!("class_{a}"));
    });
}
/// Consecutive class blocks must have strictly decreasing gauge priority.
#[test]
fn gauge_priority_is_strictly_decreasing_in_class_index() {
    let family = toy_family(8, 3, 5);
    let specs = family.build_block_specs();
    for window in specs.windows(2) {
        let (earlier, later) = (&window[0], &window[1]);
        // `name` already carries the `class_` prefix, so it is printed as-is
        // rather than being wrapped in another `class_{}`.
        assert!(
            earlier.gauge_priority > later.gauge_priority,
            "{} priority {} must exceed {} priority {}",
            earlier.name,
            earlier.gauge_priority,
            later.name,
            later.gauge_priority,
        );
    }
}
/// Every block's design must have the same shape as the family's design.
#[test]
fn block_specs_share_design_shape_with_family() {
    let family = toy_family(8, 3, 4);
    let n = family.design.nrows();
    let p = family.design.ncols();
    for spec in family.build_block_specs().iter() {
        assert_eq!(spec.design.nrows(), n);
        assert_eq!(spec.design.ncols(), p);
    }
}
/// Each class block must receive the complete penalty list, one log-λ per
/// penalty, and one null-space dimension per penalty — for one term and for several.
#[test]
fn each_block_carries_the_full_per_term_penalty_list() {
    // Single-penalty case.
    let single_specs = toy_family(6, 4, 3).build_block_specs();
    for spec in single_specs.iter() {
        assert_eq!(spec.penalties.len(), 1);
        assert_eq!(spec.initial_log_lambdas.len(), 1);
        assert_eq!(spec.nullspace_dims.len(), 1);
    }

    // Multi-penalty case: three scaled-identity penalties on a cosine design.
    let (n_obs, p, k, n_terms) = (9, 5, 4, 3);
    let mut y = Array2::<f64>::zeros((n_obs, k));
    for row in 0..n_obs {
        y[[row, row % k]] = 1.0;
    }
    let weights = Array1::<f64>::ones(n_obs);
    let design = Arc::new(Array2::<f64>::from_shape_fn((n_obs, p), |(i, j)| {
        ((i + j + 1) as f64).cos()
    }));
    let mut penalty_list = Vec::with_capacity(n_terms);
    for t in 0..n_terms {
        let scaled_identity = Array2::<f64>::from_shape_fn((p, p), |(i, j)| {
            if i == j { (t + 1) as f64 } else { 0.0 }
        });
        penalty_list.push(crate::custom_family::PenaltyMatrix::Dense(scaled_identity));
    }
    let penalties = Arc::new(penalty_list);
    let nullspace_dims = Arc::new(vec![0usize; n_terms]);
    let multi = MultinomialFamily::new(y, weights, k, design, penalties, nullspace_dims)
        .expect("multi-term MultinomialFamily must construct");

    let specs = multi.build_block_specs();
    assert_eq!(specs.len(), k - 1, "one block per active class");
    for spec in specs.iter() {
        assert_eq!(
            spec.penalties.len(),
            n_terms,
            "each block must carry the full per-term penalty list (#561)"
        );
        assert_eq!(
            spec.initial_log_lambdas.len(),
            n_terms,
            "each block must carry one independent λ per smooth term (#561)"
        );
        assert_eq!(spec.nullspace_dims.len(), n_terms);
    }
}
/// Passing one block to a K = 3 family (which needs two) must be an error.
#[test]
fn collect_eta_matrix_rejects_wrong_block_count() {
    let family = toy_family(4, 2, 3);
    let lone_block = ParameterBlockState {
        beta: Array1::<f64>::zeros(2),
        eta: Array1::<f64>::zeros(4),
    };
    let outcome = family.collect_eta_matrix(&[lone_block]);
    let err = outcome.expect_err("wrong block count must error");
    assert!(err.contains("expects 2 blocks"));
}
/// At η = 0 with unit weights the log-likelihood must equal `N · ln(1/K)`.
#[test]
fn evaluate_uniform_eta_zero_matches_uniform_softmax() {
    let family = toy_family(5, 2, 3);
    let p = family.design.ncols();
    let m = family.active_classes();
    let n = family.weights.len();
    let mut block_states: Vec<ParameterBlockState> = Vec::with_capacity(m);
    for _ in 0..m {
        block_states.push(ParameterBlockState {
            beta: Array1::<f64>::zeros(p),
            eta: Array1::<f64>::zeros(n),
        });
    }
    let eval = family
        .evaluate(&block_states)
        .expect("baseline evaluate must succeed at β = 0");
    let class_count = family.total_classes as f64;
    let expected = (n as f64) * (1.0 / class_count).ln();
    let diff = (eval.log_likelihood - expected).abs();
    assert!(
        diff < 1.0e-10,
        "baseline log-lik {} != {}",
        eval.log_likelihood,
        expected,
    );
    assert_eq!(eval.blockworking_sets.len(), m);
}
/// A zero coefficient direction must produce an all-zero derivative kernel.
#[test]
fn directional_fisher_jet_along_zero_vanishes() {
    let family = toy_family(4, 2, 3);
    let m = family.active_classes();
    let eta = Array2::<f64>::zeros((family.weights.len(), m));
    let d_beta = Array1::<f64>::zeros(m * family.design.ncols());
    let jet = family
        .directional_fisher_jet(eta.view(), &d_beta)
        .expect("zero direction must be valid");
    jet.iter().for_each(|&v| {
        assert!(v.abs() < 1.0e-14, "expected zero kernel, got {v}");
    });
}
/// With K = 4 and P = 5 the flattened coefficient length must be 3 · 5.
#[test]
fn beta_flat_dim_equals_active_classes_times_p() {
    let flat_dim = toy_family(3, 5, 4).beta_flat_dim();
    assert_eq!(flat_dim, 3 * 5);
}
/// The workspace's matvec, in-place matvec and diagonal must agree with the
/// dense Hessian for a family of probe vectors.
#[test]
fn matrix_free_matvec_matches_dense_hessian_dot() {
    let family = toy_family(7, 3, 4);
    let p = family.design.ncols();
    let m = family.active_classes();
    let n = family.weights.len();
    let design = family.design.view();

    // Block states at a fixed non-zero β, with η accumulated column by column.
    let mut block_states: Vec<ParameterBlockState> = Vec::with_capacity(m);
    for a in 0..m {
        let beta =
            Array1::<f64>::from_shape_fn(p, |i| 0.3 * ((a + 1) as f64) - 0.1 * (i as f64));
        let eta = Array1::<f64>::from_shape_fn(n, |row| {
            (0..p).map(|i| design[[row, i]] * beta[i]).sum()
        });
        block_states.push(ParameterBlockState { beta, eta });
    }

    let specs = family.build_block_specs();
    let ws = family
        .exact_newton_joint_hessian_workspace(&block_states, &specs)
        .expect("workspace build must succeed")
        .expect("workspace must be present");
    let dense = family
        .exact_newton_joint_hessian(&block_states)
        .expect("dense Hessian must build")
        .expect("dense Hessian must be present");

    let total = m * p;
    for seed in 0..total {
        // Probe: a unit spike at `seed` on top of a small cosine background.
        let v = Array1::<f64>::from_shape_fn(total, |i| {
            if i == seed {
                1.0
            } else {
                0.07 * ((i + 1) as f64).cos()
            }
        });
        let mf = ws
            .hessian_matvec(&v)
            .expect("matvec must succeed")
            .expect("matvec must be present");
        let dv = dense.dot(&v);
        for (a, b) in mf.iter().zip(dv.iter()) {
            assert!(
                (a - b).abs() < 1.0e-9,
                "matrix-free matvec {a} != dense {b}"
            );
        }

        // The in-place variant must overwrite the NaN-filled buffer.
        let mut into = Array1::<f64>::from_elem(total, f64::NAN);
        let wrote = ws
            .hessian_matvec_into(&v, &mut into)
            .expect("matvec_into must succeed");
        assert!(wrote, "matvec_into must report it wrote");
        for (a, b) in into.iter().zip(mf.iter()) {
            assert!((a - b).abs() < 1.0e-12, "matvec_into {a} != matvec {b}");
        }
    }

    let mf_diag = ws
        .hessian_diagonal()
        .expect("diagonal must succeed")
        .expect("diagonal must be present");
    let dense_diag = dense.diag();
    for (a, b) in mf_diag.iter().zip(dense_diag.iter()) {
        assert!((a - b).abs() < 1.0e-9, "matrix-free diag {a} != dense {b}");
    }
}
/// K = 4 must give three names and the same number of links.
#[test]
fn parameter_names_emit_one_label_per_active_class() {
    let family = toy_family(2, 1, 4);
    let names = family.parameter_names();
    let links = family.parameter_links();
    assert_eq!(names, vec!["class_0", "class_1", "class_2"]);
    assert_eq!(links.len(), names.len());
}
/// Constructing a family with a single class must fail.
#[test]
fn new_rejects_k_less_than_two() {
    let n = 3;
    let y = array![[1.0], [1.0], [1.0]];
    let w = Array1::<f64>::ones(n);
    let x = Arc::new(Array2::<f64>::ones((n, 1)));
    let s = Arc::new(vec![crate::custom_family::PenaltyMatrix::Dense(
        Array2::<f64>::zeros((1, 1)),
    )]);
    let nd = Arc::new(vec![0usize]);
    let outcome = MultinomialFamily::new(y, w, 1, x, s, nd);
    let err = outcome.expect_err("K = 1 must be rejected");
    assert!(err.contains("K"));
}
/// Deterministic family with caller-supplied weights: observation `i` is in
/// class `(3·i + 1) % k`, the design mixes a sine pattern with a column
/// trend, and there is a single identity penalty.
fn family_with_weights(
    n_obs: usize,
    p: usize,
    k: usize,
    weights: Array1<f64>,
) -> MultinomialFamily {
    let mut y = Array2::<f64>::zeros((n_obs, k));
    for row in 0..n_obs {
        y[[row, (3 * row + 1) % k]] = 1.0;
    }
    let design_values = Array2::<f64>::from_shape_fn((n_obs, p), |(i, j)| {
        0.7 * ((i as f64 + 1.0) * 0.31 + (j as f64) * 0.53).sin() - 0.2 * (j as f64)
    });
    let identity =
        Array2::<f64>::from_shape_fn((p, p), |(i, j)| if i == j { 1.0 } else { 0.0 });
    let penalties = Arc::new(vec![crate::custom_family::PenaltyMatrix::Dense(identity)]);
    let nullspace_dims = Arc::new(vec![0usize]);
    MultinomialFamily::new(
        y,
        weights,
        k,
        Arc::new(design_values),
        penalties,
        nullspace_dims,
    )
    .expect("family_with_weights must construct")
}
/// Builds one block state per coefficient vector, with `eta = X · beta`.
fn states_at_betas(
    family: &MultinomialFamily,
    betas: &[Array1<f64>],
) -> Vec<ParameterBlockState> {
    let x = family.design.view();
    let mut states = Vec::with_capacity(betas.len());
    for b in betas {
        states.push(ParameterBlockState {
            beta: b.clone(),
            eta: x.dot(b),
        });
    }
    states
}
/// Deterministic coefficient vectors: `m` vectors of length `p`, each entry a
/// scaled sine of a linear combination of the class and coefficient indices.
fn sample_betas(m: usize, p: usize, scale: f64) -> Vec<Array1<f64>> {
    let mut betas = Vec::with_capacity(m);
    for a in 0..m {
        let beta = Array1::from_shape_fn(p, |i| {
            scale * (0.41 * (a as f64 + 1.0) - 0.23 * (i as f64) + 0.13).sin()
        });
        betas.push(beta);
    }
    betas
}
/// Reference gradient used by the finite-difference test: entry `a·P + i` is
/// `Σ_row x[row, i] · w[row] · (prob[row, a] − y[row, a])`.
fn neglogl_grad(family: &MultinomialFamily, states: &[ParameterBlockState]) -> Array1<f64> {
    let eta = family.collect_eta_matrix(states).expect("eta collect");
    let probs = family.row_probabilities(eta.view());
    let x = family.design.view();
    let n = family.weights.len();
    let p = family.design.ncols();
    let m = family.active_classes();
    Array1::<f64>::from_shape_fn(m * p, |flat| {
        let (a, i) = (flat / p, flat % p);
        (0..n).fold(0.0_f64, |acc, row| {
            acc + x[[row, i]]
                * family.weights[row]
                * (probs[[row, a]] - family.y_one_hot[[row, a]])
        })
    })
}
/// Returns `betas` shifted by `factor · v`, where `v` is the flattened
/// (class-major) direction. Panics on an empty `betas` slice.
fn perturb(betas: &[Array1<f64>], v: &Array1<f64>, factor: f64) -> Vec<Array1<f64>> {
    let p = betas[0].len();
    let mut shifted = Vec::with_capacity(betas.len());
    for (a, b) in betas.iter().enumerate() {
        let offset = a * p;
        shifted.push(Array1::from_shape_fn(p, |i| b[i] + factor * v[offset + i]));
    }
    shifted
}
/// Matrix-free and dense Hessian products must agree (relative 1e-10) for
/// several pseudo-random directions under non-uniform weights.
#[test]
fn matrix_free_matvec_matches_dense_across_directions() {
    let (n, p, k) = (13, 4, 4);
    let weights = Array1::from_shape_fn(n, |i| 0.5 + 0.5 * ((i as f64) * 0.37).cos().abs());
    let family = family_with_weights(n, p, k, weights);
    let m = family.active_classes();
    let total = m * p;
    let states = states_at_betas(&family, &sample_betas(m, p, 0.8));
    let specs = family.build_block_specs();
    let ws = family
        .exact_newton_joint_hessian_workspace(&states, &specs)
        .expect("workspace build")
        .expect("workspace present");
    let dense = ws.hessian_dense().expect("dense").expect("dense present");
    for seed in 0..8usize {
        let v = Array1::from_shape_fn(total, |idx| {
            ((seed * 31 + idx * 17 + 5) as f64 * 0.123).cos()
        });
        let mf = ws.hessian_matvec(&v).expect("matvec").expect("matvec some");
        let dv = dense.dot(&v);
        // Largest absolute deviation and the magnitude it is measured against.
        let (max_abs, scale) = mf.iter().zip(dv.iter()).fold(
            (0.0_f64, 1.0e-300_f64),
            |(worst, magnitude), (&got, &want)| {
                (worst.max((got - want).abs()), magnitude.max(want.abs()))
            },
        );
        assert!(
            max_abs <= 1.0e-10 * scale + 1.0e-13,
            "seed {seed}: matrix-free matvec deviates from dense by {max_abs} (scale {scale})"
        );
    }
}
/// With large coefficients (scale 12) the matrix-free product must stay
/// finite and still agree with the dense Hessian product.
#[test]
fn matrix_free_matvec_does_not_allocate_dense_but_matches_at_extreme_eta() {
    let (n, p, k) = (9, 3, 5);
    let family = family_with_weights(n, p, k, Array1::<f64>::ones(n));
    let m = family.active_classes();
    let total = m * p;
    let states = states_at_betas(&family, &sample_betas(m, p, 12.0));
    let specs = family.build_block_specs();
    let ws = family
        .exact_newton_joint_hessian_workspace(&states, &specs)
        .expect("workspace build")
        .expect("workspace present");
    let dense = ws.hessian_dense().expect("dense").expect("dense present");
    let v = Array1::from_shape_fn(total, |idx| ((idx as f64) * 0.91 - 1.0).sin());
    let mf = ws.hessian_matvec(&v).expect("matvec").expect("matvec some");
    let dv = dense.dot(&v);
    let mut max_abs = 0.0_f64;
    let mut scale = 1.0e-300_f64;
    for idx in 0..total {
        let (got, want) = (mf[idx], dv[idx]);
        assert!(got.is_finite(), "matvec entry {idx} not finite");
        max_abs = max_abs.max((got - want).abs());
        scale = scale.max(want.abs());
    }
    assert!(
        max_abs <= 1.0e-10 * scale + 1.0e-13,
        "extreme-η matvec deviates from dense by {max_abs} (scale {scale})"
    );
}
/// Rows with zero weight must not break agreement with the dense product.
#[test]
fn matrix_free_matvec_handles_zero_weight_rows() {
    let (n, p, k) = (10, 3, 3);
    let mut w = Array1::<f64>::ones(n);
    for dropped in [2, 5, 9] {
        w[dropped] = 0.0;
    }
    let family = family_with_weights(n, p, k, w);
    let m = family.active_classes();
    let total = m * p;
    let states = states_at_betas(&family, &sample_betas(m, p, 0.6));
    let specs = family.build_block_specs();
    let ws = family
        .exact_newton_joint_hessian_workspace(&states, &specs)
        .expect("workspace build")
        .expect("workspace present");
    let dense = ws.hessian_dense().expect("dense").expect("dense present");
    let v = Array1::from_shape_fn(total, |idx| (idx as f64 + 0.5).cos());
    let mf = ws.hessian_matvec(&v).expect("matvec").expect("matvec some");
    let dv = dense.dot(&v);
    let (max_abs, scale) = mf.iter().zip(dv.iter()).fold(
        (0.0_f64, 1.0e-300_f64),
        |(worst, magnitude), (&got, &want)| {
            (worst.max((got - want).abs()), magnitude.max(want.abs()))
        },
    );
    assert!(
        max_abs <= 1.0e-10 * scale + 1.0e-13,
        "zero-weight matvec deviates from dense by {max_abs} (scale {scale})"
    );
}
/// The binary case (one active class) must match the dense product entry by entry.
#[test]
fn matrix_free_matvec_binary_k_equals_two() {
    let (n, p, k) = (7, 3, 2);
    let family = family_with_weights(n, p, k, Array1::<f64>::ones(n));
    let m = family.active_classes();
    assert_eq!(m, 1);
    let total = m * p;
    let states = states_at_betas(&family, &sample_betas(m, p, 1.1));
    let specs = family.build_block_specs();
    let ws = family
        .exact_newton_joint_hessian_workspace(&states, &specs)
        .expect("workspace build")
        .expect("workspace present");
    let dense = ws.hessian_dense().expect("dense").expect("dense present");
    let v = Array1::from_shape_fn(total, |idx| (idx as f64 * 0.7 + 0.2).sin());
    let mf = ws.hessian_matvec(&v).expect("matvec").expect("matvec some");
    let dv = dense.dot(&v);
    for idx in 0..total {
        let tolerance = 1.0e-12 * (1.0 + dv[idx].abs());
        assert!(
            (mf[idx] - dv[idx]).abs() <= tolerance,
            "binary matvec entry {idx}: {} vs {}",
            mf[idx],
            dv[idx]
        );
    }
}
/// The in-place matvec must overwrite a pre-filled buffer and equal the
/// owned-return variant exactly.
#[test]
fn matrix_free_matvec_into_matches_owned_return() {
    let (n, p, k) = (8, 3, 4);
    let family = family_with_weights(n, p, k, Array1::<f64>::ones(n));
    let m = family.active_classes();
    let total = m * p;
    let states = states_at_betas(&family, &sample_betas(m, p, 0.9));
    let specs = family.build_block_specs();
    let ws = family
        .exact_newton_joint_hessian_workspace(&states, &specs)
        .expect("workspace build")
        .expect("workspace present");
    let v = Array1::from_shape_fn(total, |idx| (idx as f64 * 1.7 - 0.3).cos());
    let owned = ws.hessian_matvec(&v).expect("matvec").expect("matvec some");
    // Pre-fill with a sentinel so stale contents would be detected.
    let mut out = Array1::from_elem(total, 7.0_f64);
    let wrote = ws.hessian_matvec_into(&v, &mut out).expect("matvec_into");
    assert!(wrote, "matvec_into must report it wrote a result");
    assert_eq!(out, owned, "into-variant must match owned return bitwise");
}
/// The matrix-free diagonal must equal the dense Hessian's diagonal exactly.
#[test]
fn matrix_free_diagonal_is_bit_identical_to_dense_diag() {
    let (n, p, k) = (11, 4, 4);
    let weights = Array1::from_shape_fn(n, |i| 0.25 + (i as f64 % 3.0));
    let family = family_with_weights(n, p, k, weights);
    let m = family.active_classes();
    let total = m * p;
    let states = states_at_betas(&family, &sample_betas(m, p, 0.7));
    let specs = family.build_block_specs();
    let ws = family
        .exact_newton_joint_hessian_workspace(&states, &specs)
        .expect("workspace build")
        .expect("workspace present");
    let dense = ws.hessian_dense().expect("dense").expect("dense present");
    let diag = ws
        .hessian_diagonal()
        .expect("diagonal")
        .expect("diagonal some");
    for idx in 0..total {
        let from_dense = dense[[idx, idx]];
        assert_eq!(
            diag[idx],
            from_dense,
            "matrix-free diagonal entry {idx} must equal dense diagonal bit-for-bit"
        );
    }
}
/// The Hessian–vector product must match a central finite difference of the
/// reference gradient along the same direction.
#[test]
fn matrix_free_matvec_matches_gradient_finite_difference() {
    let (n, p, k) = (12, 3, 4);
    let weights = Array1::from_shape_fn(n, |i| 0.4 + 0.3 * ((i as f64) * 0.6).sin().abs());
    let family = family_with_weights(n, p, k, weights);
    let m = family.active_classes();
    let total = m * p;
    let betas = sample_betas(m, p, 0.5);
    let states = states_at_betas(&family, &betas);
    let specs = family.build_block_specs();
    let ws = family
        .exact_newton_joint_hessian_workspace(&states, &specs)
        .expect("workspace build")
        .expect("workspace present");
    let v = Array1::from_shape_fn(total, |idx| 0.5 * ((idx as f64 * 1.3 + 0.7).sin()));
    let hv = ws.hessian_matvec(&v).expect("matvec").expect("matvec some");

    // Reference gradients at β ± eps·v.
    let eps = 1.0e-6;
    let forward_states = states_at_betas(&family, &perturb(&betas, &v, eps));
    let g_plus = neglogl_grad(&family, &forward_states);
    let backward_states = states_at_betas(&family, &perturb(&betas, &v, -eps));
    let g_minus = neglogl_grad(&family, &backward_states);

    let mut max_abs = 0.0_f64;
    let mut scale = 1.0e-300_f64;
    for idx in 0..total {
        let fd = (g_plus[idx] - g_minus[idx]) / (2.0 * eps);
        max_abs = max_abs.max((hv[idx] - fd).abs());
        scale = scale.max(fd.abs());
    }
    assert!(
        max_abs <= 1.0e-5 * scale + 1.0e-7,
        "matvec vs gradient finite-difference deviates by {max_abs} (scale {scale})"
    );
}
}