use std::ops::Range;
use ndarray::{Array2, s};
use crate::smooth::{BlockwisePenalty, PenaltyStructureHint};
pub use gam_problem::CoefficientPriorMean;
pub use gam_problem::EstimationError;
/// A penalty matrix together with the coefficient columns it applies to.
///
/// * `Block` — a `local` matrix covering only the columns in `col_range`, optionally
///   carrying a structure hint and an analytic penalty operator.
/// * `Dense` — a full matrix over all coefficients.
/// * `DenseWithMean` — a full matrix over all coefficients plus a coefficient prior mean.
#[derive(Clone)]
pub enum PenaltySpec {
    /// Penalty restricted to a contiguous range of global coefficient columns.
    Block {
        /// Global column range that `local` acts on.
        col_range: Range<usize>,
        /// Penalty matrix expressed in the coordinates of `col_range`.
        local: Array2<f64>,
        /// Prior mean associated with the penalized coefficients.
        prior_mean: CoefficientPriorMean,
        /// Optional hint describing the structure of `local`.
        structure_hint: Option<PenaltyStructureHint>,
        /// Optional analytic operator attached to this penalty.
        op: Option<std::sync::Arc<dyn crate::analytic_penalties::PenaltyOp>>,
    },
    /// Full penalty matrix over every coefficient.
    Dense(Array2<f64>),
    /// Full penalty matrix over every coefficient, paired with a prior mean.
    DenseWithMean {
        /// Prior mean associated with the penalized coefficients.
        prior_mean: CoefficientPriorMean,
        /// Full penalty matrix.
        matrix: Array2<f64>,
    },
}
impl std::fmt::Debug for PenaltySpec {
    /// Prints a compact summary of the penalty.
    ///
    /// Matrices are shown only by their `rows×cols` shape, never their contents.
    /// An attached operator is shown only through its `dim()` value.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            PenaltySpec::Dense(m) => {
                let mut builder = f.debug_tuple("Dense");
                builder.field(&format_args!("{}×{}", m.nrows(), m.ncols()));
                builder.finish()
            }
            PenaltySpec::DenseWithMean { matrix, prior_mean } => {
                let mut builder = f.debug_struct("DenseWithMean");
                builder.field(
                    "matrix",
                    &format_args!("{}×{}", matrix.nrows(), matrix.ncols()),
                );
                builder.field("prior_mean", prior_mean);
                builder.finish()
            }
            PenaltySpec::Block {
                local,
                col_range,
                prior_mean,
                structure_hint,
                op,
            } => {
                let mut builder = f.debug_struct("Block");
                builder.field(
                    "local",
                    &format_args!("{}×{}", local.nrows(), local.ncols()),
                );
                builder.field("col_range", col_range);
                builder.field("prior_mean", prior_mean);
                builder.field("structure_hint", structure_hint);
                // Only the operator's dimension is shown, not the operator itself.
                builder.field("op", &op.as_ref().map(|o| o.dim()));
                builder.finish()
            }
        }
    }
}
impl PenaltySpec {
    /// Returns the range of global coefficient columns this penalty acts on.
    ///
    /// `Block` penalties report their stored `col_range`; dense variants always span `0..p`.
    ///
    /// # Panics
    ///
    /// Panics if a dense penalty matrix does not have exactly `p` columns.
    pub fn col_range(&self, p: usize) -> Range<usize> {
        match self {
            PenaltySpec::Block { col_range, .. } => col_range.clone(),
            PenaltySpec::Dense(m) | PenaltySpec::DenseWithMean { matrix: m, .. } => {
                assert_eq!(m.ncols(), p);
                0..p
            }
        }
    }

    /// Returns the analytic penalty operator attached to this penalty, if any.
    ///
    /// Only `Block` penalties can carry an operator; dense variants always yield `None`.
    pub fn op(&self) -> Option<&std::sync::Arc<dyn crate::analytic_penalties::PenaltyOp>> {
        match self {
            PenaltySpec::Block { op, .. } => op.as_ref(),
            PenaltySpec::Dense(_) | PenaltySpec::DenseWithMean { .. } => None,
        }
    }

    /// Builds a `Block` spec from a [`BlockwisePenalty`], moving all of its fields.
    pub fn from_blockwise(bp: BlockwisePenalty) -> Self {
        PenaltySpec::Block {
            local: bp.local,
            col_range: bp.col_range,
            prior_mean: bp.prior_mean,
            structure_hint: bp.structure_hint,
            op: bp.op,
        }
    }

    /// Builds a `Block` spec from a borrowed [`BlockwisePenalty`] by cloning its fields.
    ///
    /// The operator, if present, is shared via `Arc::clone` rather than deep-copied.
    pub fn from_blockwise_ref(bp: &BlockwisePenalty) -> Self {
        PenaltySpec::Block {
            local: bp.local.clone(),
            col_range: bp.col_range.clone(),
            prior_mean: bp.prior_mean.clone(),
            structure_hint: bp.structure_hint.clone(),
            op: bp.op.clone(),
        }
    }

    /// Materialises the penalty as a square dense matrix.
    ///
    /// Dense variants return a copy of their matrix. A `Block` penalty is placed into a
    /// zero matrix whose side is `col_range.end` (or `local.nrows()` if that is larger).
    /// That is only as large as needed to hold the block, not the full model dimension.
    /// Use [`PenaltySpec::to_global`] when a `p_total × p_total` matrix is required.
    pub fn to_dense(&self) -> Array2<f64> {
        match self {
            PenaltySpec::Dense(m) | PenaltySpec::DenseWithMean { matrix: m, .. } => m.clone(),
            PenaltySpec::Block {
                local, col_range, ..
            } => {
                let p = col_range.end.max(local.nrows());
                let mut out = Array2::zeros((p, p));
                out.slice_mut(s![col_range.clone(), col_range.clone()])
                    .assign(local);
                out
            }
        }
    }

    /// Embeds the penalty into a `p_total × p_total` matrix.
    ///
    /// Dense variants are returned as a copy. A `Block` penalty is written into the
    /// `col_range × col_range` sub-block of an otherwise zero matrix.
    ///
    /// # Panics
    ///
    /// Panics if a dense penalty is not exactly `p_total × p_total`.
    /// Panics if a block's `col_range` extends past `p_total`.
    pub fn to_global(&self, p_total: usize) -> Array2<f64> {
        match self {
            PenaltySpec::Dense(m) | PenaltySpec::DenseWithMean { matrix: m, .. } => {
                // Check both dimensions: checking rows alone would let a non-square
                // matrix through as if it were a valid global penalty.
                assert!(
                    m.nrows() == p_total && m.ncols() == p_total,
                    "dense penalty must be {p_total}x{p_total}, got {}x{}",
                    m.nrows(),
                    m.ncols()
                );
                m.clone()
            }
            PenaltySpec::Block {
                local, col_range, ..
            } => {
                // Fail with a descriptive message rather than an opaque slicing panic.
                assert!(
                    col_range.end <= p_total,
                    "block penalty col_range {}..{} exceeds p_total={p_total}",
                    col_range.start,
                    col_range.end
                );
                let mut out = Array2::zeros((p_total, p_total));
                out.slice_mut(s![col_range.clone(), col_range.clone()])
                    .assign(local);
                out
            }
        }
    }
}
/// Checks that penalty `idx` has a shape compatible with a model of `p` coefficients.
///
/// * `Block`: `col_range` must be ordered (`start <= end`), and `local` must be square
///   with side `col_range.len()`. The range must also end at or before `p`.
/// * `Dense` / `DenseWithMean`: the matrix must be exactly `p × p`.
///
/// `context` is prefixed to every error message to identify the calling site.
///
/// # Errors
///
/// Reports the first violated constraint through `crate::bail_invalid_estim!`.
/// Presumably this returns early with an [`EstimationError`]; verify against the macro.
pub fn validate_penalty_spec_shape(
    idx: usize,
    spec: &PenaltySpec,
    p: usize,
    context: &str,
) -> Result<(), EstimationError> {
    match spec {
        PenaltySpec::Block {
            local, col_range, ..
        } => {
            // A reversed range has `len() == 0`, so with a 0x0 `local` it would pass the
            // size check below. Slicing a matrix with it later would still panic.
            if col_range.start > col_range.end {
                crate::bail_invalid_estim!(
                    "{context}: block penalty {idx} col_range {}..{} has start > end",
                    col_range.start,
                    col_range.end
                );
            }
            let bd = col_range.len();
            if local.nrows() != bd || local.ncols() != bd {
                crate::bail_invalid_estim!(
                    "{context}: block penalty {idx} local matrix must be {bd}x{bd}, got {}x{}",
                    local.nrows(),
                    local.ncols()
                );
            }
            if col_range.end > p {
                crate::bail_invalid_estim!(
                    "{context}: block penalty {idx} col_range {}..{} exceeds p={p}",
                    col_range.start,
                    col_range.end
                );
            }
        }
        PenaltySpec::Dense(m) | PenaltySpec::DenseWithMean { matrix: m, .. } => {
            if m.nrows() != p || m.ncols() != p {
                crate::bail_invalid_estim!(
                    "{context}: dense penalty {idx} must be {p}x{p}, got {}x{}",
                    m.nrows(),
                    m.ncols()
                );
            }
        }
    }
    Ok(())
}