use super::*;
/// Per-latent-dimension penalty with one learnable weight per dimension.
///
/// The target slice is read as a flat buffer in which element `(n, j)`
/// (observation `n`, latent dimension `j`) lives at `n * latent_dim + j`.
#[derive(Debug, Clone)]
pub struct ARDPenalty {
/// Slice of the parameter vector this penalty covers; its length divided by
/// `latent_dim` gives the number of observations.
pub target: PsiSlice,
/// Number of latent dimensions (the stride of the flat layout).
pub latent_dim: usize,
/// Base weight handed to `resolve_learnable_weight` together with the
/// per-dimension `rho` entry. `new` sets it to `1.0`.
pub weight: f64,
/// Optional schedule for `weight`. `new` sets it to `None`.
// NOTE(review): presumably consumed by the `impl_with_weight_schedule!` /
// `impl_scalar_apply_schedule!` expansions — confirm against the macros.
pub weight_schedule: Option<ScalarWeightSchedule>,
/// For latent dimension `j`, `rho_indices[j]` is the position read from the
/// `rho` vector. `new` sets it to `0..latent_dim`.
pub rho_indices: Vec<usize>,
/// Multiplier on the `-0.5 * ln(lambda_j)` term of the penalty value.
/// `new` sets it to `target.len() / latent_dim` (integer division).
pub n_eff: f64,
}
impl ARDPenalty {
    /// Creates a penalty over `target` with `latent_dim` latent dimensions.
    ///
    /// The result has `weight = 1.0`, no weight schedule, `rho_indices` equal
    /// to `0..latent_dim`, and `n_eff` equal to `target.len() / latent_dim`
    /// (integer division) converted to `f64`.
    ///
    /// # Panics
    ///
    /// Panics if `latent_dim` is zero.
    #[must_use]
    pub fn new(target: PsiSlice, latent_dim: usize) -> Self {
        assert!(latent_dim > 0, "ARDPenalty requires latent_dim > 0");
        // The assertion guarantees a non-zero divisor, so no separate
        // zero-dimension fallback is needed here.
        let n_obs = target.len() / latent_dim;
        let rho_indices = (0..latent_dim).collect();
        Self {
            target,
            latent_dim,
            weight: 1.0,
            weight_schedule: None,
            rho_indices,
            n_eff: n_obs as f64,
        }
    }

    // NOTE(review): presumably generates the weight-schedule builder for the
    // `weight` field — confirm against the macro definition.
    impl_with_weight_schedule!(weight);

    /// Overrides the effective observation count used by the log term.
    ///
    /// # Errors
    ///
    /// Returns an error message when `n_eff` is NaN, infinite, or negative.
    #[must_use = "build error must be handled"]
    pub fn with_n_eff(mut self, n_eff: f64) -> Result<Self, String> {
        if !(n_eff.is_finite() && n_eff >= 0.0) {
            return Err(format!(
                "ARDPenalty::with_n_eff requires a finite non-negative value, got {n_eff}"
            ));
        }
        self.n_eff = n_eff;
        Ok(self)
    }

    /// Expands the penalty into one single-index ridge block per element.
    ///
    /// Each block covers the one-element range starting at
    /// `global_offset + target.range.start + n * latent_dim + j` and is built
    /// with a ridge weight of `1.0`. Blocks are emitted dimension by dimension
    /// (`j` outer, `n` inner).
    ///
    /// # Panics
    ///
    /// Panics on division by zero if `latent_dim` has been set to zero.
    pub fn as_blockwise(&self, global_offset: usize) -> Vec<BlockwisePenalty> {
        let n_obs = self.target.len() / self.latent_dim;
        let mut out = Vec::with_capacity(n_obs * self.latent_dim);
        for j in 0..self.latent_dim {
            for n in 0..n_obs {
                let idx = global_offset + self.target.range.start + n * self.latent_dim + j;
                // NOTE(review): `with_op(None)` semantics are defined elsewhere;
                // kept exactly as in the original call.
                out.push(BlockwisePenalty::ridge(idx..idx + 1, 1.0).with_op(None));
            }
        }
        out
    }
}
impl AnalyticPenalty for ARDPenalty {
    /// Always reports [`PenaltyTier::Psi`].
    fn tier(&self) -> PenaltyTier {
        PenaltyTier::Psi
    }

    /// Evaluates the penalty.
    ///
    /// For each latent dimension `j`, with `lambda_j` resolved from `weight`
    /// and `rho[rho_indices[j]]`, this adds
    /// `0.5 * lambda_j * sum_n target[n * d + j]^2 - 0.5 * n_eff * ln(lambda_j)`.
    fn value(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> f64 {
        let dim = self.latent_dim;
        let rows = target.len() / dim;
        (0..dim).fold(0.0_f64, |total, col| {
            let lambda = resolve_learnable_weight(self.weight, rho[self.rho_indices[col]]);
            // Walk column `col` of the flat row-major buffer, in row order.
            let sum_sq = (col..).step_by(dim).take(rows).fold(0.0_f64, |partial, flat| {
                let x = target[flat];
                partial + x * x
            });
            total + (0.5 * lambda * sum_sq - 0.5 * self.n_eff * lambda.ln())
        })
    }

    /// Gradient with respect to the target: `lambda_j * target[n * d + j]`
    /// at every covered position; positions past `rows * d` stay zero.
    fn grad_target(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
        let dim = self.latent_dim;
        let rows = target.len() / dim;
        let mut grad = Array1::<f64>::zeros(target.len());
        for col in 0..dim {
            let lambda = resolve_learnable_weight(self.weight, rho[self.rho_indices[col]]);
            for flat in (col..).step_by(dim).take(rows) {
                grad[flat] = lambda * target[flat];
            }
        }
        grad
    }

    /// Diagonal of the Hessian with respect to the target: `lambda_j` at
    /// every position belonging to latent dimension `j`. Always `Some`.
    fn hessian_diag(
        &self,
        target: ArrayView1<'_, f64>,
        rho: ArrayView1<'_, f64>,
    ) -> Option<Array1<f64>> {
        let dim = self.latent_dim;
        let rows = target.len() / dim;
        let mut curvature = Array1::<f64>::zeros(target.len());
        for col in 0..dim {
            let lambda = resolve_learnable_weight(self.weight, rho[self.rho_indices[col]]);
            for flat in (col..).step_by(dim).take(rows) {
                curvature[flat] = lambda;
            }
        }
        Some(curvature)
    }

    /// Gradient with respect to `rho`: entry `rho_indices[j]` receives
    /// `0.5 * lambda_j * sum_n target[n * d + j]^2 - 0.5 * n_eff`.
    // NOTE(review): the output has `rho_count()` (= `latent_dim`) entries but
    // is indexed by `rho_indices[j]`; a non-default `rho_indices` containing a
    // value >= `latent_dim` would panic here — confirm the indexing convention.
    fn grad_rho(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
        let dim = self.latent_dim;
        let rows = target.len() / dim;
        let mut grad = Array1::<f64>::zeros(self.rho_count());
        for col in 0..dim {
            let slot = self.rho_indices[col];
            let lambda = resolve_learnable_weight(self.weight, rho[slot]);
            let sum_sq = (col..).step_by(dim).take(rows).fold(0.0_f64, |partial, flat| {
                let x = target[flat];
                partial + x * x
            });
            grad[slot] = 0.5 * lambda * sum_sq - 0.5 * self.n_eff;
        }
        grad
    }

    /// One learnable `rho` entry per latent dimension.
    fn rho_count(&self) -> usize {
        self.latent_dim
    }

    /// Short identifier for this penalty.
    fn name(&self) -> &str {
        "ard"
    }

    // NOTE(review): presumably generates the schedule-application method for
    // the `weight` field — confirm against the macro definition.
    impl_scalar_apply_schedule!(weight);
}