use ndarray::{Array1, ArrayView1, ArrayView2};
use crate::estimate::EstimationError;
/// Checks that `pred_density` is a usable `(points × candidates)` matrix and
/// returns its `(rows, cols)` shape.
///
/// Rejects, with `EstimationError::InvalidInput`:
/// * a matrix with zero rows or zero columns,
/// * any entry that is NaN, infinite, or negative,
/// * any row whose entries are all zero (no candidate gives that point density).
fn validate_density(pred_density: ArrayView2<'_, f64>) -> Result<(usize, usize), EstimationError> {
    let (n, t) = pred_density.dim();
    if n == 0 || t == 0 {
        return Err(EstimationError::InvalidInput(
            "stacking requires a non-empty (points × candidates) predictive-density matrix"
                .to_string(),
        ));
    }
    for (i, row) in pred_density.outer_iter().enumerate() {
        let mut total = 0.0;
        for (c, &v) in row.iter().enumerate() {
            // NaN fails `is_finite`, so it is caught by the first test.
            if !v.is_finite() || v < 0.0 {
                return Err(EstimationError::InvalidInput(format!(
                    "stacking: predictive density[{i},{c}] must be finite and non-negative; got {v}"
                )));
            }
            total += v;
        }
        // Entries are non-negative here, so a non-positive sum means every entry is zero.
        if total <= 0.0 {
            return Err(EstimationError::InvalidInput(format!(
                "stacking: point {i} has zero predictive density under every candidate"
            )));
        }
    }
    Ok((n, t))
}
pub fn mean_log_score(pred_density: ArrayView2<'_, f64>, weights: ArrayView1<'_, f64>) -> f64 {
let n = pred_density.nrows();
let t = pred_density.ncols();
let mut acc = 0.0;
for i in 0..n {
let mut mix = 0.0;
for c in 0..t {
mix += weights[c] * pred_density[[i, c]];
}
acc += mix.max(f64::MIN_POSITIVE).ln();
}
acc / n as f64
}
/// Computes stacking weights over the candidate predictive distributions.
///
/// The weights maximise the mean log score (see [`mean_log_score`]) of the
/// mixture `Σ_c w_c · pred_density[[i, c]]` over the simplex. The fit uses the
/// EM / multiplicative update
/// `w_c ← w_c · (1/n) Σ_i p_ic / Σ_k w_k p_ik`, renormalised to sum to one.
/// It starts from uniform weights.
///
/// # Arguments
///
/// * `pred_density` - `(points × candidates)` matrix. Entry `[i, c]` is
///   candidate `c`'s predictive density at point `i`.
/// * `max_iter` - maximum number of EM iterations. With `0`, uniform weights
///   are returned.
/// * `tol` - iteration stops early once no weight changes by `tol` or more
///   in a single step.
///
/// # Returns
///
/// A length-`t` vector of non-negative weights summing to one.
///
/// # Errors
///
/// Returns `EstimationError::InvalidInput` if the matrix is empty, contains a
/// negative or non-finite entry, or has a point with zero density under every
/// candidate.
pub fn stacking_weights(
    pred_density: ArrayView2<'_, f64>,
    max_iter: usize,
    tol: f64,
) -> Result<Array1<f64>, EstimationError> {
    let (n, t) = validate_density(pred_density)?;
    let n_points = n as f64;

    // The EM update is invariant to multiplying any row by a positive constant,
    // so divide each row by its largest entry once up front. Without this,
    // tiny but valid densities (e.g. subnormals) give mixture values below the
    // `f64::MIN_POSITIVE` floor. Those points would then silently lose almost
    // all influence on the update. Dividing (rather than multiplying by
    // `1.0 / row_max`) avoids overflow when `row_max` itself is subnormal.
    let mut scaled = pred_density.to_owned();
    for mut row in scaled.outer_iter_mut() {
        // validate_density guarantees a strictly positive, finite row maximum.
        let row_max = row.fold(0.0_f64, |acc, &v| acc.max(v));
        row.mapv_inplace(|v| v / row_max);
    }

    let mut w = Array1::<f64>::from_elem(t, 1.0 / t as f64);
    for _ in 0..max_iter {
        // E-step: numer[c] = Σ_i p_ic / mix_i.
        let mut numer = Array1::<f64>::zeros(t);
        for row in scaled.outer_iter() {
            let mix: f64 = row.iter().zip(w.iter()).map(|(&p, &wc)| wc * p).sum();
            // Guard against a zero mixture if the relevant weights have underflowed.
            let inv = 1.0 / mix.max(f64::MIN_POSITIVE);
            for (acc, &p) in numer.iter_mut().zip(row.iter()) {
                *acc += p * inv;
            }
        }

        // M-step: multiplicative update, then renormalise onto the simplex.
        let mut next = &w * &numer / n_points;
        let total: f64 = next.sum();
        if total > 0.0 {
            next.mapv_inplace(|v| v / total);
        }

        // Measure convergence on the weights that are actually returned.
        let max_delta = next
            .iter()
            .zip(w.iter())
            .map(|(&new, &old)| (new - old).abs())
            .fold(0.0_f64, f64::max);
        w = next;
        if max_delta < tol {
            break;
        }
    }
    Ok(w)
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    /// Builds a row-major `(rows × cols)` density matrix from a flat vector.
    fn dense(rows: usize, cols: usize, values: Vec<f64>) -> Array2<f64> {
        Array2::from_shape_vec((rows, cols), values).unwrap()
    }

    /// Candidate 0 has the higher density at every point, so almost all the
    /// weight should move onto it.
    #[test]
    fn concentrates_on_a_dominant_candidate() {
        let density = dense(
            5,
            2,
            vec![0.9, 0.1, 0.8, 0.2, 0.95, 0.05, 0.7, 0.3, 0.85, 0.15],
        );
        let w = stacking_weights(density.view(), 1000, 1e-12).unwrap();
        assert!(w[0] > 0.95, "expected dominance on candidate 0, got {w:?}");
        assert!((w.sum() - 1.0).abs() < 1e-10);
    }

    /// Each candidate is good on half of the points. The fitted mixture should
    /// strictly beat either candidate used alone.
    #[test]
    fn beats_or_matches_the_best_single_candidate() {
        let density = dense(
            6,
            2,
            vec![
                0.9, 0.1, 0.8, 0.2, 0.85, 0.15, 0.1, 0.9, 0.2, 0.8, 0.15, 0.85,
            ],
        );
        let w = stacking_weights(density.view(), 2000, 1e-12).unwrap();
        let stacked = mean_log_score(density.view(), w.view());
        let score_of =
            |single: Vec<f64>| mean_log_score(density.view(), Array1::from(single).view());
        let best_single = score_of(vec![1.0, 0.0]).max(score_of(vec![0.0, 1.0]));
        assert!(
            stacked >= best_single - 1e-9,
            "stacked {stacked} < best single {}",
            best_single
        );
        assert!(stacked > best_single + 1e-3);
        assert!(w[0] > 0.2 && w[1] > 0.2, "weights {w:?}");
    }

    /// Returned weights must be a probability vector.
    #[test]
    fn weights_live_on_the_simplex() {
        let density = dense(
            4,
            3,
            vec![0.5, 0.3, 0.2, 0.4, 0.4, 0.2, 0.6, 0.1, 0.3, 0.2, 0.5, 0.3],
        );
        let w = stacking_weights(density.view(), 500, 1e-12).unwrap();
        assert!((w.sum() - 1.0).abs() < 1e-10);
        assert!(w.iter().all(|&v| v >= 0.0));
    }

    /// Empty input, an all-zero row, and a negative entry are all rejected.
    #[test]
    fn rejects_degenerate_inputs() {
        let degenerate = [
            Array2::<f64>::zeros((0, 2)),
            dense(2, 2, vec![0.0, 0.0, 0.5, 0.5]),
            dense(2, 2, vec![0.5, -0.1, 0.5, 0.5]),
        ];
        for bad in &degenerate {
            assert!(stacking_weights(bad.view(), 10, 1e-9).is_err());
        }
    }
}