use crate::families::custom_family::{
BlockWorkingSet, CustomFamily, ExactNewtonJointGradientEvaluation,
ExactNewtonJointHessianWorkspace, FamilyEvaluation, ParameterBlockSpec, ParameterBlockState,
PenaltyMatrix,
};
use crate::families::gamlss::{FamilyMetadata, ParameterLink};
use crate::families::vector_response::{MultinomialLogitLikelihood, VectorLikelihood};
use crate::matrix::{DenseDesignMatrix, DesignMatrix, SymmetricMatrix};
use crate::pirls::dense_block_xtwx;
use ndarray::{Array1, Array2, Array3, ArrayView2};
use std::sync::Arc;
/// Multinomial-logit family with one coefficient block per active class.
///
/// All `K - 1` active classes share the same design matrix and the same
/// penalty. Instances built through `MultinomialFamily::new` have had shapes
/// and finiteness validated.
#[derive(Clone)]
pub struct MultinomialFamily {
/// `N × K` response indicators, one column per class. `new` checks the
/// column count and finiteness only; a one-hot encoding is assumed, not checked.
pub y_one_hot: Array2<f64>,
/// Length-`N` prior weights (finite and non-negative per `new`). Rows with
/// zero weight are skipped by the derivative and matrix-free kernels.
pub weights: Array1<f64>,
/// `K`, the total number of classes (at least 2 per `new`).
pub total_classes: usize,
/// `N × P` design shared by every class block.
pub design: Arc<Array2<f64>>,
/// `P × P` penalty shared by every class block.
pub penalty: Arc<Array2<f64>>,
/// Null-space dimension forwarded to each block spec's `nullspace_dims`.
pub penalty_nullspace_dim: usize,
/// Softmax likelihood configured with `total_classes` classes and `weights`
/// as row weights; kept private so it always agrees with the public fields
/// it was built from at construction time.
likelihood: MultinomialLogitLikelihood,
}
impl MultinomialFamily {
pub const fn active_classes(&self) -> usize {
self.total_classes - 1
}
pub fn parameter_names(&self) -> Vec<String> {
(0..self.active_classes())
.map(|a| format!("class_{a}"))
.collect()
}
pub fn parameter_links(&self) -> Vec<ParameterLink> {
vec![ParameterLink::Identity; self.active_classes()]
}
pub fn metadata() -> FamilyMetadata {
FamilyMetadata {
name: "multinomial_logit",
parameternames: &[],
parameter_links: &[],
}
}
pub fn new(
y_one_hot: Array2<f64>,
weights: Array1<f64>,
total_classes: usize,
design: Arc<Array2<f64>>,
penalty: Arc<Array2<f64>>,
penalty_nullspace_dim: usize,
) -> Result<Self, String> {
if total_classes < 2 {
return Err(format!(
"MultinomialFamily requires K ≥ 2 classes (got {total_classes})"
));
}
let (n, k) = y_one_hot.dim();
if k != total_classes {
return Err(format!(
"MultinomialFamily: y_one_hot has {k} columns but total_classes = {total_classes}"
));
}
if weights.len() != n {
return Err(format!(
"MultinomialFamily: weights length {} != N = {n}",
weights.len()
));
}
for (i, &v) in weights.iter().enumerate() {
if !(v.is_finite() && v >= 0.0) {
return Err(format!(
"MultinomialFamily: weights[{i}] must be finite and non-negative (got {v})"
));
}
}
if design.nrows() != n {
return Err(format!(
"MultinomialFamily: design has {} rows, expected {n}",
design.nrows()
));
}
let p = design.ncols();
if penalty.dim() != (p, p) {
return Err(format!(
"MultinomialFamily: penalty shape {:?} != (P, P) = ({p}, {p})",
penalty.dim()
));
}
for ((i, j), &v) in y_one_hot.indexed_iter() {
if !v.is_finite() {
return Err(format!(
"MultinomialFamily: y_one_hot[{i},{j}] must be finite (got {v})"
));
}
}
for ((i, j), &v) in design.indexed_iter() {
if !v.is_finite() {
return Err(format!(
"MultinomialFamily: design[{i},{j}] must be finite (got {v})"
));
}
}
for ((i, j), &v) in penalty.indexed_iter() {
if !v.is_finite() {
return Err(format!(
"MultinomialFamily: penalty[{i},{j}] must be finite (got {v})"
));
}
}
let likelihood = MultinomialLogitLikelihood::with_classes(total_classes)
.map_err(|e| format!("MultinomialFamily: {e}"))?
.with_row_weights(weights.clone())
.map_err(|e| format!("MultinomialFamily: {e}"))?;
Ok(Self {
y_one_hot,
weights,
total_classes,
design,
penalty,
penalty_nullspace_dim,
likelihood,
})
}
pub fn build_block_specs(&self) -> Vec<ParameterBlockSpec> {
let nullspace_dims = vec![self.penalty_nullspace_dim];
let m = self.active_classes();
(0..m)
.map(|a| {
let priority = 100u8.saturating_add(u8::try_from(m - a).unwrap_or(u8::MAX));
ParameterBlockSpec {
name: format!("class_{a}"),
design: DesignMatrix::Dense(DenseDesignMatrix::from(self.design.clone())),
offset: Array1::<f64>::zeros(self.design.nrows()),
penalties: vec![PenaltyMatrix::Dense((*self.penalty).clone())],
nullspace_dims: nullspace_dims.clone(),
initial_log_lambdas: Array1::<f64>::zeros(1),
initial_beta: None,
gauge_priority: priority,
jacobian_callback: None,
stacked_design: None,
stacked_offset: None,
}
})
.collect()
}
pub fn beta_flat_dim(&self) -> usize {
self.active_classes() * self.design.ncols()
}
fn collect_eta_matrix(
&self,
block_states: &[ParameterBlockState],
) -> Result<Array2<f64>, String> {
let m = self.active_classes();
if block_states.len() != m {
return Err(format!(
"MultinomialFamily expects {m} blocks (K-1), got {}",
block_states.len()
));
}
let n = self.weights.len();
let mut eta = Array2::<f64>::zeros((n, m));
for (a, state) in block_states.iter().enumerate() {
if state.eta.len() != n {
return Err(format!(
"MultinomialFamily block {a} eta length {} != N = {n}",
state.eta.len()
));
}
for row in 0..n {
eta[[row, a]] = state.eta[row];
}
}
Ok(eta)
}
fn evaluate_row_kernels(
&self,
eta: ArrayView2<'_, f64>,
) -> (f64, Array3<f64>, Array2<f64>) {
let log_lik = self.likelihood.log_lik(eta, self.y_one_hot.view());
let fisher = self.likelihood.hess_block(eta, self.y_one_hot.view());
let grad_eta_logl = self.likelihood.grad_eta(eta, self.y_one_hot.view());
(log_lik, fisher, grad_eta_logl)
}
fn assemble_block_diagonal_working_sets(
&self,
fisher: &Array3<f64>,
grad_eta_logl: &Array2<f64>,
) -> Result<Vec<BlockWorkingSet>, String> {
let n = self.weights.len();
let p = self.design.ncols();
let m = self.active_classes();
let design_view = self.design.view();
let mut sets = Vec::with_capacity(m);
for a in 0..m {
let mut grad = Array1::<f64>::zeros(p);
for i in 0..p {
let mut acc = 0.0_f64;
for row in 0..n {
acc += design_view[[row, i]] * (-grad_eta_logl[[row, a]]);
}
grad[i] = acc;
}
let mut hess = Array2::<f64>::zeros((p, p));
for row in 0..n {
let w_aa = fisher[[row, a, a]];
if w_aa == 0.0 {
continue;
}
for i in 0..p {
let xi = design_view[[row, i]];
if xi == 0.0 {
continue;
}
let scaled = w_aa * xi;
for j in 0..p {
hess[[i, j]] += scaled * design_view[[row, j]];
}
}
}
for i in 0..p {
for j in (i + 1)..p {
let avg = 0.5 * (hess[[i, j]] + hess[[j, i]]);
hess[[i, j]] = avg;
hess[[j, i]] = avg;
}
}
sets.push(BlockWorkingSet::ExactNewton {
gradient: grad,
hessian: SymmetricMatrix::Dense(hess),
});
}
Ok(sets)
}
fn assemble_joint_hessian(&self, fisher: &Array3<f64>) -> Result<Array2<f64>, String> {
dense_block_xtwx(self.design.view(), fisher.view(), None)
.map_err(|e| format!("MultinomialFamily joint Hessian assembly: {e}"))
}
fn assemble_joint_gradient(&self, grad_eta_logl: &Array2<f64>) -> Array1<f64> {
let n = self.weights.len();
let p = self.design.ncols();
let m = self.active_classes();
let design_view = self.design.view();
let mut out = Array1::<f64>::zeros(m * p);
for a in 0..m {
for i in 0..p {
let mut acc = 0.0_f64;
for row in 0..n {
acc += design_view[[row, i]] * grad_eta_logl[[row, a]];
}
out[a * p + i] = acc;
}
}
out
}
fn d_eta_from_d_beta(&self, d_beta_flat: &Array1<f64>) -> Result<Array2<f64>, String> {
let p = self.design.ncols();
let m = self.active_classes();
let n = self.design.nrows();
if d_beta_flat.len() != m * p {
return Err(format!(
"MultinomialFamily direction length {} != (K-1)·P = {}",
d_beta_flat.len(),
m * p
));
}
let mut d_eta = Array2::<f64>::zeros((n, m));
let design_view = self.design.view();
for a in 0..m {
for row in 0..n {
let mut acc = 0.0_f64;
for i in 0..p {
acc += design_view[[row, i]] * d_beta_flat[a * p + i];
}
d_eta[[row, a]] = acc;
}
}
Ok(d_eta)
}
fn row_probabilities(&self, eta: ArrayView2<'_, f64>) -> Array2<f64> {
self.likelihood.probabilities(eta)
}
fn directional_fisher_jet(
&self,
eta: ArrayView2<'_, f64>,
d_beta_flat: &Array1<f64>,
) -> Result<Array3<f64>, String> {
let n = self.weights.len();
let m = self.active_classes();
let probs_full = self.row_probabilities(eta);
let d_eta = self.d_eta_from_d_beta(d_beta_flat)?;
let mut out = Array3::<f64>::zeros((n, m, m));
let mut dp = vec![0.0_f64; m];
for row in 0..n {
let w = self.weights[row];
if w == 0.0 {
continue;
}
let mut s = 0.0_f64;
for a in 0..m {
s += probs_full[[row, a]] * d_eta[[row, a]];
}
for a in 0..m {
dp[a] = probs_full[[row, a]] * (d_eta[[row, a]] - s);
}
for a in 0..m {
let pa = probs_full[[row, a]];
out[[row, a, a]] = w * (dp[a] - 2.0 * dp[a] * pa);
for b in (a + 1)..m {
let pb = probs_full[[row, b]];
let off = w * (-(dp[a] * pb + pa * dp[b]));
out[[row, a, b]] = off;
out[[row, b, a]] = off;
}
}
}
Ok(out)
}
fn second_directional_fisher_jet(
&self,
eta: ArrayView2<'_, f64>,
d_beta_u: &Array1<f64>,
d_beta_v: &Array1<f64>,
) -> Result<Array3<f64>, String> {
let n = self.weights.len();
let m = self.active_classes();
let probs_full = self.row_probabilities(eta);
let d_eta_u = self.d_eta_from_d_beta(d_beta_u)?;
let d_eta_v = self.d_eta_from_d_beta(d_beta_v)?;
let mut out = Array3::<f64>::zeros((n, m, m));
let mut dp_u = vec![0.0_f64; m];
let mut dp_v = vec![0.0_f64; m];
let mut ddp = vec![0.0_f64; m];
for row in 0..n {
let w = self.weights[row];
if w == 0.0 {
continue;
}
let mut s_u = 0.0_f64;
let mut s_v = 0.0_f64;
for a in 0..m {
s_u += probs_full[[row, a]] * d_eta_u[[row, a]];
s_v += probs_full[[row, a]] * d_eta_v[[row, a]];
}
for a in 0..m {
let pa = probs_full[[row, a]];
dp_u[a] = pa * (d_eta_u[[row, a]] - s_u);
dp_v[a] = pa * (d_eta_v[[row, a]] - s_v);
}
let mut ds_u_dv = 0.0_f64;
for c in 0..m {
ds_u_dv += dp_v[c] * d_eta_u[[row, c]];
}
for a in 0..m {
let pa = probs_full[[row, a]];
ddp[a] = dp_v[a] * (d_eta_u[[row, a]] - s_u) + pa * (-ds_u_dv);
}
for a in 0..m {
let pa = probs_full[[row, a]];
out[[row, a, a]] = w
* (ddp[a]
- 2.0 * ddp[a] * pa
- 2.0 * dp_u[a] * dp_v[a]);
for b in (a + 1)..m {
let pb = probs_full[[row, b]];
let off = w
* (-(ddp[a] * pb
+ dp_u[a] * dp_v[b]
+ dp_v[a] * dp_u[b]
+ pa * ddp[b]));
out[[row, a, b]] = off;
out[[row, b, a]] = off;
}
}
}
Ok(out)
}
}
impl CustomFamily for MultinomialFamily {
    /// The curvature depends on β through the softmax probabilities.
    fn exact_newton_joint_hessian_beta_dependent(&self) -> bool {
        true
    }

    /// A dense joint Hessian, including the cross-class blocks, is provided.
    fn has_explicit_joint_hessian(&self) -> bool {
        true
    }

    /// Class blocks are coupled (the off-diagonal Fisher entries `−w p_a p_b`
    /// are non-zero), so this family opts into the joint outer hyper path.
    fn requires_joint_outer_hyper_path(&self) -> bool {
        true
    }

    /// Cost estimate for the coupled joint coefficient Hessian over `N` rows.
    fn coefficient_hessian_cost(&self, specs: &[ParameterBlockSpec]) -> u64 {
        crate::custom_family::joint_coupled_coefficient_hessian_cost(
            self.weights.len() as u64,
            specs,
        )
    }

    /// Log-likelihood plus one block-diagonal `ExactNewton` working set per class.
    fn evaluate(&self, block_states: &[ParameterBlockState]) -> Result<FamilyEvaluation, String> {
        let eta = self.collect_eta_matrix(block_states)?;
        let (log_lik, fisher, grad_eta_logl) = self.evaluate_row_kernels(eta.view());
        let working_sets = self.assemble_block_diagonal_working_sets(&fisher, &grad_eta_logl)?;
        Ok(FamilyEvaluation {
            log_likelihood: log_lik,
            blockworking_sets: working_sets,
        })
    }

    /// Log-likelihood only; skips the curvature and score.
    fn log_likelihood_only(&self, block_states: &[ParameterBlockState]) -> Result<f64, String> {
        let eta = self.collect_eta_matrix(block_states)?;
        Ok(self.likelihood.log_lik(eta.view(), self.y_one_hot.view()))
    }

    /// Dense joint `(K-1)·P × (K-1)·P` curvature at the given block states.
    fn exact_newton_joint_hessian(
        &self,
        block_states: &[ParameterBlockState],
    ) -> Result<Option<Array2<f64>>, String> {
        let eta = self.collect_eta_matrix(block_states)?;
        // Only the curvature is needed here. Calling `hess_block` directly avoids
        // the log-likelihood and score that `evaluate_row_kernels` also computes.
        let fisher = self.likelihood.hess_block(eta.view(), self.y_one_hot.view());
        self.assemble_joint_hessian(&fisher).map(Some)
    }

    /// Log-likelihood and flattened score `Xᵀ ∂ℓ/∂η` (class-major layout).
    fn exact_newton_joint_gradient_evaluation(
        &self,
        block_states: &[ParameterBlockState],
        _block_specs: &[ParameterBlockSpec],
    ) -> Result<Option<ExactNewtonJointGradientEvaluation>, String> {
        let eta = self.collect_eta_matrix(block_states)?;
        let log_lik = self.likelihood.log_lik(eta.view(), self.y_one_hot.view());
        let grad_eta_logl = self.likelihood.grad_eta(eta.view(), self.y_one_hot.view());
        let gradient = self.assemble_joint_gradient(&grad_eta_logl);
        Ok(Some(ExactNewtonJointGradientEvaluation {
            log_likelihood: log_lik,
            gradient,
        }))
    }

    /// Builds a workspace holding the block states and the softmax
    /// probabilities at them, for matrix-free Hessian products.
    fn exact_newton_joint_hessian_workspace(
        &self,
        block_states: &[ParameterBlockState],
        _block_specs: &[ParameterBlockSpec],
    ) -> Result<Option<Arc<dyn ExactNewtonJointHessianWorkspace>>, String> {
        let eta = self.collect_eta_matrix(block_states)?;
        let probs = self.row_probabilities(eta.view());
        Ok(Some(Arc::new(MultinomialHessianWorkspace {
            family: self.clone(),
            block_states: block_states.to_vec(),
            probs,
        })))
    }

    /// `D H[dβ]`, assembled as `Xᵀ (DW[dβ]) X`.
    fn exact_newton_joint_hessian_directional_derivative(
        &self,
        block_states: &[ParameterBlockState],
        d_beta_flat: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String> {
        let eta = self.collect_eta_matrix(block_states)?;
        let dh_fisher = self.directional_fisher_jet(eta.view(), d_beta_flat)?;
        let dh = dense_block_xtwx(self.design.view(), dh_fisher.view(), None)
            .map_err(|e| format!("MultinomialFamily directional H assembly: {e}"))?;
        Ok(Some(dh))
    }

    /// `D² H[u, v]`, assembled as `Xᵀ (D²W[u, v]) X`.
    fn exact_newton_joint_hessiansecond_directional_derivative(
        &self,
        block_states: &[ParameterBlockState],
        d_beta_u_flat: &Array1<f64>,
        d_beta_v_flat: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String> {
        let eta = self.collect_eta_matrix(block_states)?;
        let d2h_fisher =
            self.second_directional_fisher_jet(eta.view(), d_beta_u_flat, d_beta_v_flat)?;
        let d2h = dense_block_xtwx(self.design.view(), d2h_fisher.view(), None)
            .map_err(|e| format!("MultinomialFamily second directional H assembly: {e}"))?;
        Ok(Some(d2h))
    }
}
/// Snapshot used for matrix-free joint-Hessian products at fixed block states.
struct MultinomialHessianWorkspace {
/// Owned copy of the family (design and penalty are `Arc`-shared; the other
/// fields are cloned).
family: MultinomialFamily,
/// Block states the workspace was built at; reused for dense and
/// directional-derivative requests.
block_states: Vec<ParameterBlockState>,
/// Softmax probabilities from `row_probabilities` at `block_states`;
/// columns `0..K-1` are read as the active classes.
probs: Array2<f64>,
}
impl MultinomialHessianWorkspace {
    /// Overwrites `out` with `H · v` without forming `H`.
    ///
    /// For each weighted row:
    /// * project `v` onto the row, one class block at a time: `z_a = x · v_a`;
    /// * compute the probability-weighted mean `s = Σ_a p_a z_a`;
    /// * scatter `w p_a (z_a − s) · x` back into class block `a`.
    ///
    /// The scattered quantity is `Σ_b W_ab z_b` with `W_ab = w (p_a δ_ab − p_a p_b)`.
    fn matvec_accumulate(&self, v: &Array1<f64>, out: &mut Array1<f64>) {
        let classes = self.family.active_classes();
        let cols = self.family.design.ncols();
        let x = self.family.design.view();
        out.fill(0.0);
        let mut projected = vec![0.0_f64; classes];
        for (row, &row_weight) in self.family.weights.iter().enumerate() {
            if row_weight == 0.0 {
                continue;
            }
            let x_row = x.row(row);
            let mut mean = 0.0_f64;
            for (class, slot) in projected.iter_mut().enumerate() {
                let offset = class * cols;
                let dot = x_row
                    .iter()
                    .enumerate()
                    .fold(0.0_f64, |acc, (i, &xi)| acc + xi * v[offset + i]);
                *slot = dot;
                mean += self.probs[[row, class]] * dot;
            }
            for (class, &dot) in projected.iter().enumerate() {
                let resid = row_weight * self.probs[[row, class]] * (dot - mean);
                if resid == 0.0 {
                    continue;
                }
                let offset = class * cols;
                for (i, &xi) in x_row.iter().enumerate() {
                    out[offset + i] += xi * resid;
                }
            }
        }
    }
}
impl ExactNewtonJointHessianWorkspace for MultinomialHessianWorkspace {
    /// Dense joint Hessian, recomputed by the family at the stored block states.
    fn hessian_dense(&self) -> Result<Option<Array2<f64>>, String> {
        self.family.exact_newton_joint_hessian(&self.block_states)
    }

    /// Matrix-free products are always supported by this workspace.
    fn hessian_matvec_available(&self) -> bool {
        true
    }

    /// Returns `H · v` in a newly allocated vector.
    ///
    /// # Errors
    /// Fails when `v` is not of length `(K-1)·P`.
    fn hessian_matvec(&self, v: &Array1<f64>) -> Result<Option<Array1<f64>>, String> {
        let total = self.family.beta_flat_dim();
        if v.len() == total {
            let mut product = Array1::<f64>::zeros(total);
            self.matvec_accumulate(v, &mut product);
            Ok(Some(product))
        } else {
            Err(format!(
                "MultinomialHessianWorkspace::hessian_matvec: v len {} != (K-1)·P = {total}",
                v.len()
            ))
        }
    }

    /// Writes `H · v` into `out` and reports `true`.
    ///
    /// # Errors
    /// Fails when `v` or `out` is not of length `(K-1)·P`; `v` is checked first.
    fn hessian_matvec_into(&self, v: &Array1<f64>, out: &mut Array1<f64>) -> Result<bool, String> {
        let total = self.family.beta_flat_dim();
        let v_ok = v.len() == total;
        let out_ok = out.len() == total;
        match (v_ok, out_ok) {
            (false, _) => Err(format!(
                "MultinomialHessianWorkspace::hessian_matvec_into: v len {} != (K-1)·P = {total}",
                v.len()
            )),
            (true, false) => Err(format!(
                "MultinomialHessianWorkspace::hessian_matvec_into: out len {} != (K-1)·P = {total}",
                out.len()
            )),
            (true, true) => {
                self.matvec_accumulate(v, out);
                Ok(true)
            }
        }
    }

    /// Diagonal of `H`.
    ///
    /// Entry `(a, i)` is `Σ_rows w p_a (1 − p_a) x_i²`.
    fn hessian_diagonal(&self) -> Result<Option<Array1<f64>>, String> {
        let classes = self.family.active_classes();
        let cols = self.family.design.ncols();
        let x = self.family.design.view();
        let mut diag = Array1::<f64>::zeros(classes * cols);
        for (row, &row_weight) in self.family.weights.iter().enumerate() {
            if row_weight == 0.0 {
                continue;
            }
            let x_row = x.row(row);
            for class in 0..classes {
                let prob = self.probs[[row, class]];
                let curvature = row_weight * prob * (1.0 - prob);
                if curvature == 0.0 {
                    continue;
                }
                let offset = class * cols;
                for (i, &xi) in x_row.iter().enumerate() {
                    diag[offset + i] += curvature * xi * xi;
                }
            }
        }
        Ok(Some(diag))
    }

    /// `D H[dβ]` at the stored block states, delegated to the family.
    fn directional_derivative(
        &self,
        d_beta_flat: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String> {
        let states = &self.block_states;
        self.family
            .exact_newton_joint_hessian_directional_derivative(states, d_beta_flat)
    }

    /// `D² H[u, v]` at the stored block states, delegated to the family.
    fn second_directional_derivative(
        &self,
        d_beta_u: &Array1<f64>,
        d_beta_v: &Array1<f64>,
    ) -> Result<Option<Array2<f64>>, String> {
        let states = &self.block_states;
        self.family
            .exact_newton_joint_hessiansecond_directional_derivative(states, d_beta_u, d_beta_v)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Small deterministic family with the following ingredients:
    /// * `n_obs` rows cycling through the `k` classes;
    /// * unit weights;
    /// * a `sin`-valued `n_obs × p` design;
    /// * an identity penalty.
    fn toy_family(n_obs: usize, p: usize, k: usize) -> MultinomialFamily {
        let y = {
            let mut y = Array2::<f64>::zeros((n_obs, k));
            for i in 0..n_obs {
                y[[i, i % k]] = 1.0;
            }
            y
        };
        let weights = Array1::<f64>::ones(n_obs);
        let design = Arc::new(Array2::<f64>::from_shape_fn((n_obs, p), |(i, j)| {
            ((i + j + 1) as f64).sin()
        }));
        let penalty = Arc::new(Array2::<f64>::from_shape_fn((p, p), |(i, j)| {
            if i == j { 1.0 } else { 0.0 }
        }));
        MultinomialFamily::new(y, weights, k, design, penalty, 0)
            .expect("toy MultinomialFamily must construct")
    }

    #[test]
    fn block_specs_have_one_per_active_class_in_order() {
        let family = toy_family(8, 3, 4);
        let specs = family.build_block_specs();
        assert_eq!(specs.len(), 3, "expected K-1 = 3 active blocks for K=4");
        for (a, spec) in specs.iter().enumerate() {
            assert_eq!(spec.name, format!("class_{a}"));
        }
    }

    #[test]
    fn gauge_priority_is_strictly_decreasing_in_class_index() {
        let family = toy_family(8, 3, 5);
        let specs = family.build_block_specs();
        for window in specs.windows(2) {
            // `name` is already "class_N"; do not prefix it again.
            assert!(
                window[0].gauge_priority > window[1].gauge_priority,
                "{} priority {} must exceed {} priority {}",
                window[0].name,
                window[0].gauge_priority,
                window[1].name,
                window[1].gauge_priority,
            );
        }
    }

    #[test]
    fn block_specs_share_design_shape_with_family() {
        let family = toy_family(8, 3, 4);
        let specs = family.build_block_specs();
        let (n, p) = (family.design.nrows(), family.design.ncols());
        for spec in &specs {
            assert_eq!(spec.design.nrows(), n);
            assert_eq!(spec.design.ncols(), p);
        }
    }

    #[test]
    fn each_block_carries_exactly_one_penalty_for_kronecker_form() {
        let family = toy_family(6, 4, 3);
        let specs = family.build_block_specs();
        for spec in &specs {
            assert_eq!(spec.penalties.len(), 1);
            assert_eq!(spec.initial_log_lambdas.len(), 1);
        }
    }

    #[test]
    fn collect_eta_matrix_rejects_wrong_block_count() {
        let family = toy_family(4, 2, 3);
        let single = vec![ParameterBlockState {
            beta: Array1::<f64>::zeros(2),
            eta: Array1::<f64>::zeros(4),
        }];
        let err = family
            .collect_eta_matrix(&single)
            .expect_err("wrong block count must error");
        assert!(err.contains("expects 2 blocks"));
    }

    #[test]
    fn evaluate_uniform_eta_zero_matches_uniform_softmax() {
        let family = toy_family(5, 2, 3);
        let p = family.design.ncols();
        let m = family.active_classes();
        let n = family.weights.len();
        let block_states: Vec<ParameterBlockState> = (0..m)
            .map(|_| ParameterBlockState {
                beta: Array1::<f64>::zeros(p),
                eta: Array1::<f64>::zeros(n),
            })
            .collect();
        let eval = family
            .evaluate(&block_states)
            .expect("baseline evaluate must succeed at β = 0");
        let expected = (n as f64) * (1.0 / (family.total_classes as f64)).ln();
        let diff = (eval.log_likelihood - expected).abs();
        assert!(
            diff < 1.0e-10,
            "baseline log-lik {} != {}",
            eval.log_likelihood,
            expected,
        );
        assert_eq!(eval.blockworking_sets.len(), m);
    }

    #[test]
    fn directional_fisher_jet_along_zero_vanishes() {
        let family = toy_family(4, 2, 3);
        let p = family.design.ncols();
        let m = family.active_classes();
        let n = family.weights.len();
        let eta = Array2::<f64>::zeros((n, m));
        let d_beta = Array1::<f64>::zeros(m * p);
        let jet = family
            .directional_fisher_jet(eta.view(), &d_beta)
            .expect("zero direction must be valid");
        for &v in jet.iter() {
            assert!(v.abs() < 1.0e-14, "expected zero kernel, got {v}");
        }
    }

    #[test]
    fn directional_fisher_jet_matches_central_difference_of_fisher() {
        let family = toy_family(6, 3, 4);
        let p = family.design.ncols();
        let m = family.active_classes();
        let n = family.weights.len();
        let eta = Array2::<f64>::from_shape_fn((n, m), |(i, a)| 0.2 * ((i + 2 * a) as f64).cos());
        let d_beta = Array1::<f64>::from_shape_fn(m * p, |k| 0.1 * ((k + 1) as f64).sin());
        let analytic = family
            .directional_fisher_jet(eta.view(), &d_beta)
            .expect("first jet must build");
        let d_eta = family
            .d_eta_from_d_beta(&d_beta)
            .expect("direction must map to eta");
        let h = 1.0e-5_f64;
        let plus = &eta + &(&d_eta * h);
        let minus = &eta - &(&d_eta * h);
        let f_plus = family
            .likelihood
            .hess_block(plus.view(), family.y_one_hot.view());
        let f_minus = family
            .likelihood
            .hess_block(minus.view(), family.y_one_hot.view());
        for ((&exact, &fp), &fm) in analytic.iter().zip(f_plus.iter()).zip(f_minus.iter()) {
            let numeric = (fp - fm) / (2.0 * h);
            assert!(
                (exact - numeric).abs() < 1.0e-6,
                "analytic jet {exact} != central difference {numeric}"
            );
        }
    }

    #[test]
    fn second_directional_fisher_jet_matches_central_difference_of_first_jet() {
        let family = toy_family(6, 3, 4);
        let p = family.design.ncols();
        let m = family.active_classes();
        let n = family.weights.len();
        let eta = Array2::<f64>::from_shape_fn((n, m), |(i, a)| 0.2 * ((i + 2 * a) as f64).cos());
        let d_u = Array1::<f64>::from_shape_fn(m * p, |k| 0.1 * ((k + 1) as f64).sin());
        let d_v = Array1::<f64>::from_shape_fn(m * p, |k| 0.15 * ((2 * k + 1) as f64).cos());
        let analytic = family
            .second_directional_fisher_jet(eta.view(), &d_u, &d_v)
            .expect("second jet must build");
        let d_eta_v = family
            .d_eta_from_d_beta(&d_v)
            .expect("direction must map to eta");
        let h = 1.0e-5_f64;
        let plus = &eta + &(&d_eta_v * h);
        let minus = &eta - &(&d_eta_v * h);
        let jet_plus = family
            .directional_fisher_jet(plus.view(), &d_u)
            .expect("first jet (+h) must build");
        let jet_minus = family
            .directional_fisher_jet(minus.view(), &d_u)
            .expect("first jet (-h) must build");
        for ((&exact, &jp), &jm) in analytic.iter().zip(jet_plus.iter()).zip(jet_minus.iter()) {
            let numeric = (jp - jm) / (2.0 * h);
            assert!(
                (exact - numeric).abs() < 1.0e-6,
                "analytic second jet {exact} != central difference {numeric}"
            );
        }
    }

    #[test]
    fn beta_flat_dim_equals_active_classes_times_p() {
        let family = toy_family(3, 5, 4);
        assert_eq!(family.beta_flat_dim(), 3 * 5);
    }

    #[test]
    fn matrix_free_matvec_matches_dense_hessian_dot() {
        let family = toy_family(7, 3, 4);
        let p = family.design.ncols();
        let m = family.active_classes();
        let n = family.weights.len();
        let design = family.design.view();
        let block_states: Vec<ParameterBlockState> = (0..m)
            .map(|a| {
                let beta =
                    Array1::<f64>::from_shape_fn(p, |i| 0.3 * ((a + 1) as f64) - 0.1 * (i as f64));
                let eta = Array1::<f64>::from_shape_fn(n, |row| {
                    (0..p).map(|i| design[[row, i]] * beta[i]).sum()
                });
                ParameterBlockState { beta, eta }
            })
            .collect();
        let specs = family.build_block_specs();
        let ws = family
            .exact_newton_joint_hessian_workspace(&block_states, &specs)
            .expect("workspace build must succeed")
            .expect("workspace must be present");
        let dense = family
            .exact_newton_joint_hessian(&block_states)
            .expect("dense Hessian must build")
            .expect("dense Hessian must be present");
        for seed in 0..(m * p) {
            let v = Array1::<f64>::from_shape_fn(m * p, |i| {
                if i == seed { 1.0 } else { 0.07 * ((i + 1) as f64).cos() }
            });
            let mf = ws
                .hessian_matvec(&v)
                .expect("matvec must succeed")
                .expect("matvec must be present");
            let dv = dense.dot(&v);
            for (a, b) in mf.iter().zip(dv.iter()) {
                assert!(
                    (a - b).abs() < 1.0e-9,
                    "matrix-free matvec {a} != dense {b}"
                );
            }
            let mut into = Array1::<f64>::from_elem(m * p, f64::NAN);
            let wrote = ws
                .hessian_matvec_into(&v, &mut into)
                .expect("matvec_into must succeed");
            assert!(wrote, "matvec_into must report it wrote");
            for (a, b) in into.iter().zip(mf.iter()) {
                assert!((a - b).abs() < 1.0e-12, "matvec_into {a} != matvec {b}");
            }
        }
        let mf_diag = ws
            .hessian_diagonal()
            .expect("diagonal must succeed")
            .expect("diagonal must be present");
        let dense_diag = dense.diag();
        for (a, b) in mf_diag.iter().zip(dense_diag.iter()) {
            assert!((a - b).abs() < 1.0e-9, "matrix-free diag {a} != dense {b}");
        }
    }

    #[test]
    fn parameter_names_emit_one_label_per_active_class() {
        let family = toy_family(2, 1, 4);
        let names = family.parameter_names();
        assert_eq!(names, vec!["class_0", "class_1", "class_2"]);
        assert_eq!(family.parameter_links().len(), names.len());
    }

    #[test]
    fn new_rejects_k_less_than_two() {
        let n = 3;
        let y = array![[1.0], [1.0], [1.0]];
        let w = Array1::<f64>::ones(n);
        let x = Arc::new(Array2::<f64>::ones((n, 1)));
        let s = Arc::new(Array2::<f64>::zeros((1, 1)));
        let err =
            MultinomialFamily::new(y, w, 1, x, s, 0).expect_err("K = 1 must be rejected");
        // Pin the specific validation rather than any message containing "K".
        assert!(err.contains("K ≥ 2"), "unexpected error: {err}");
    }
}