use super::*;
/// Small positive threshold (1e-12).
///
/// NOTE(review): by its name this is a lower bound applied to a conditional
/// precision; the use site is not visible in this part of the file — confirm.
pub(crate) const MIN_CONDITIONAL_PRECISION: f64 = 1.0e-12;
/// Tiny positive value (1e-300), still representable as a normal `f64`.
///
/// NOTE(review): presumably substituted for a probability before taking its
/// logarithm in an entropy term so that `ln(0)` is never evaluated — confirm.
pub const ENTROPY_LOG_PROBABILITY_FLOOR: f64 = 1e-300;
/// Clamp margin (1e-12).
///
/// NOTE(review): presumably keeps IBP probabilities inside
/// `[clamp, 1 - clamp]`; the use site is not visible here — confirm.
pub(crate) const IBP_PROBABILITY_CLAMP: f64 = 1.0e-12;
/// Tolerance (1e-9).
///
/// NOTE(review): by its name, decides whether an IBP quantity counts as
/// strictly interior; the use site is not visible here — confirm.
pub(crate) const IBP_INTERIOR_TOL: f64 = 1.0e-9;
/// Floor (1e-9).
///
/// NOTE(review): presumably guards a division by an IBP count; the use site
/// is not visible here — confirm.
pub(crate) const IBP_COUNT_DENOM_FLOOR: f64 = 1.0e-9;
/// Tier tag reported by [`AnalyticPenalty::tier`].
///
/// NOTE(review): the variant names suggest which parameter group a penalty
/// acts on (beta / psi / rho); the consumers of this tag are not visible in
/// this part of the file — confirm the exact meaning against them.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PenaltyTier {
/// Presumably penalties on the `beta` parameters — confirm.
Beta,
/// Presumably penalties on the `psi` parameters — confirm.
Psi,
/// Presumably penalties on the `rho` parameters — confirm.
Rho,
}
/// A half-open index window paired with an optional latent dimension.
///
/// NOTE(review): presumably the window selects the psi entries of a larger
/// parameter vector; the consumers are not visible here — confirm.
#[derive(Debug, Clone)]
pub struct PsiSlice {
    /// Half-open range of indices covered by this slice.
    pub range: std::ops::Range<usize>,
    /// Optional latent dimension carried alongside the range; it is stored
    /// verbatim and never interpreted by the methods below.
    pub latent_dim: Option<usize>,
}

impl PsiSlice {
    /// Builds a slice covering every index in `0..len`, storing `latent_dim`
    /// unchanged.
    #[must_use]
    pub fn full(len: usize, latent_dim: Option<usize>) -> Self {
        let range = std::ops::Range { start: 0, end: len };
        Self { range, latent_dim }
    }

    /// Number of indices in the range; zero when the range is empty or
    /// reversed (`start >= end`).
    pub fn len(&self) -> usize {
        self.range.end.saturating_sub(self.range.start)
    }

    /// `true` when the range contains no index.
    pub fn is_empty(&self) -> bool {
        self.range.start >= self.range.end
    }
}
/// Scales `base_weight` by `exp(rho)` in log space, keeping the sign of
/// `base_weight`.
///
/// The log-magnitude `ln|base_weight| + rho` is clamped to `[-700, 700]`
/// before exponentiation, so the result never overflows to infinity and never
/// underflows to zero for a non-zero `base_weight`.
///
/// A `base_weight` of exactly zero (either sign) short-circuits to `0.0`
/// before any validation, so `rho` is not inspected in that case.
///
/// # Panics
///
/// Panics when `base_weight` is non-zero and either input is NaN or infinite.
pub fn resolve_learnable_weight(base_weight: f64, rho: f64) -> f64 {
    // Symmetric bound on the log-magnitude; exp(±700) is still a finite,
    // normal f64.
    const LOG_STRENGTH_BOUND: f64 = 700.0;

    if base_weight == 0.0 {
        return 0.0;
    }

    let inputs_are_finite = base_weight.is_finite() && rho.is_finite();
    assert!(
        inputs_are_finite,
        "resolve_learnable_weight requires finite inputs; got base_weight={base_weight}, rho={rho}"
    );

    let magnitude = (base_weight.abs().ln() + rho)
        .clamp(-LOG_STRENGTH_BOUND, LOG_STRENGTH_BOUND)
        .exp();
    f64::copysign(magnitude, base_weight)
}
/// Exponentiates a log-precision without overflowing or reaching zero.
///
/// `log_alpha` is clamped to `[-700, 700]` before `exp`, and the result is
/// then floored at `f64::MIN_POSITIVE`. Because `f64::max` ignores a NaN
/// operand, a NaN `log_alpha` yields `f64::MIN_POSITIVE` rather than NaN.
pub(crate) fn stable_exp_log_precision(log_alpha: f64) -> f64 {
    // Symmetric bound on the exponent; exp(±700) is a finite, normal f64.
    const LOG_STRENGTH_BOUND: f64 = 700.0;

    let bounded = log_alpha.clamp(-LOG_STRENGTH_BOUND, LOG_STRENGTH_BOUND);
    let precision = bounded.exp();
    f64::max(precision, f64::MIN_POSITIVE)
}
/// A schedule that moves a scalar weight from `w_start` towards `w_end` as an
/// iteration counter grows, following the curve selected by `kind`.
///
/// All fields are public, so a value can be assembled without going through
/// `new`; call `validate` to check such a value.
#[derive(Debug, Clone)]
pub struct ScalarWeightSchedule {
/// Weight produced at iteration 0.
pub w_start: f64,
/// Weight the schedule approaches (and, for the linear kind, reaches).
pub w_end: f64,
/// Shape of the transition between `w_start` and `w_end`.
pub kind: ScheduleKind,
/// Iteration index that the next call to `step` will evaluate.
pub iter_count: usize,
}
impl ScalarWeightSchedule {
    /// Creates a schedule positioned at iteration 0.
    ///
    /// # Errors
    ///
    /// Returns the message produced by [`Self::validate`] when the weights or
    /// the kind-specific parameters are out of range.
    #[must_use = "build error must be handled"]
    pub fn new(w_start: f64, w_end: f64, kind: ScheduleKind) -> Result<Self, String> {
        let candidate = Self {
            w_start,
            w_end,
            kind,
            iter_count: 0,
        };
        candidate.validate().map(|()| candidate)
    }

    /// Checks the schedule parameters.
    ///
    /// # Errors
    ///
    /// Returns a descriptive message when
    /// * `w_start` or `w_end` is negative, NaN or infinite,
    /// * a geometric `rate` lies outside the open interval `(0, 1)`, or
    /// * a linear schedule has `steps == 0`.
    pub fn validate(&self) -> Result<(), String> {
        let is_valid_weight = |w: f64| w.is_finite() && w >= 0.0;

        if !is_valid_weight(self.w_start) {
            return Err(format!(
                "ScalarWeightSchedule: w_start must be finite and non-negative; got {}",
                self.w_start
            ));
        }
        if !is_valid_weight(self.w_end) {
            return Err(format!(
                "ScalarWeightSchedule: w_end must be finite and non-negative; got {}",
                self.w_end
            ));
        }

        match &self.kind {
            ScheduleKind::Geometric { rate }
                if !(rate.is_finite() && *rate > 0.0 && *rate < 1.0) =>
            {
                Err(format!(
                    "ScalarWeightSchedule::Geometric: rate must be in (0, 1); got {rate}"
                ))
            }
            ScheduleKind::Linear { steps } if *steps == 0 => {
                Err("ScalarWeightSchedule::Linear: steps must be positive".into())
            }
            ScheduleKind::Geometric { .. }
            | ScheduleKind::Linear { .. }
            | ScheduleKind::ReciprocalIter => Ok(()),
        }
    }

    /// Evaluates the schedule at iteration `iter` without touching
    /// `iter_count`.
    ///
    /// With `delta = w_end - w_start`:
    /// * geometric: `w_end - delta * rate^iter`,
    /// * linear: `w_start + (iter / steps) * delta`, or `w_end` once
    ///   `iter >= steps`,
    /// * reciprocal: `w_end - delta / (1 + iter)`.
    ///
    /// The result is finally clamped to the closed interval spanned by
    /// `w_start` and `w_end`.
    pub fn current_weight(&self, iter: usize) -> f64 {
        let (start, end) = (self.w_start, self.w_end);
        let span = end - start;
        let t = iter as f64;

        let unclamped = match &self.kind {
            ScheduleKind::Geometric { rate } => end - span * rate.powf(t),
            ScheduleKind::Linear { steps } if iter >= *steps => end,
            ScheduleKind::Linear { steps } => start + (t / *steps as f64) * span,
            ScheduleKind::ReciprocalIter => end - span / (1.0 + t),
        };

        // Guards against round-off pushing the value outside the endpoints.
        unclamped.clamp(start.min(end), start.max(end))
    }

    /// Returns the weight for the current `iter_count`, then advances
    /// `iter_count` by one.
    pub fn step(&mut self) -> f64 {
        let current = self.iter_count;
        let weight = self.current_weight(current);
        self.iter_count = current + 1;
        weight
    }
}
/// A penalty term with closed-form value and derivatives.
///
/// Every evaluation method receives the vector being penalised (`target`) and
/// a `rho` vector; implementors supply the value, both gradients and the
/// bookkeeping methods, and may override the Hessian-related defaults.
pub trait AnalyticPenalty: Send + Sync {
    /// Tier this penalty belongs to.
    fn tier(&self) -> PenaltyTier;

    /// Penalty value at `target` for the given `rho`.
    fn value(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> f64;

    /// Gradient of the penalty with respect to `target`.
    fn grad_target(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64>;

    /// Diagonal of the Hessian with respect to `target`, when the Hessian is
    /// diagonal.
    ///
    /// The default knows nothing about the penalty: it returns an empty
    /// diagonal for an empty `target` and `None` otherwise.
    ///
    /// # Panics
    ///
    /// The default panics when any entry of `rho` is NaN or infinite.
    fn hessian_diag(
        &self,
        target: ArrayView1<'_, f64>,
        rho: ArrayView1<'_, f64>,
    ) -> Option<Array1<f64>> {
        let rho_is_finite = rho.iter().all(|value| value.is_finite());
        assert!(rho_is_finite, "analytic-penalty rho must be finite");
        target.is_empty().then(|| Array1::zeros(0))
    }

    /// Hessian-vector product with `v`.
    ///
    /// The default multiplies `v` element-wise by [`Self::hessian_diag`].
    ///
    /// # Panics
    ///
    /// The default panics when `hessian_diag` returns `None` (such penalties
    /// must override this method) or when the diagonal and `v` differ in
    /// length.
    fn hvp(
        &self,
        target: ArrayView1<'_, f64>,
        rho: ArrayView1<'_, f64>,
        v: ArrayView1<'_, f64>,
    ) -> Array1<f64> {
        let diag = match self.hessian_diag(target, rho) {
            Some(diag) => diag,
            None => panic!(
                "AnalyticPenalty::hvp default reached for `{}`, whose Hessian is \
                 not diagonal (hessian_diag returned None). Such a penalty must \
                 override `hvp` with its closed-form Hessian-vector product; the \
                 default never finite-differences.",
                self.name()
            ),
        };
        assert_eq!(diag.len(), v.len(), "hvp dimension mismatch");
        Array1::from_iter(diag.iter().zip(v.iter()).map(|(d, x)| d * x))
    }

    /// Diagonal of a positive-semidefinite majorizer of the Hessian.
    ///
    /// The default simply forwards to [`Self::hessian_diag`].
    fn psd_majorizer_diag(
        &self,
        target: ArrayView1<'_, f64>,
        rho: ArrayView1<'_, f64>,
    ) -> Option<Array1<f64>> {
        self.hessian_diag(target, rho)
    }

    /// Product of the positive-semidefinite majorizer with `v`.
    ///
    /// The default uses [`Self::psd_majorizer_diag`] when it is available and
    /// falls back to [`Self::hvp`] otherwise.
    ///
    /// # Panics
    ///
    /// The default panics when the majorizer diagonal and `v` differ in
    /// length, or when the `hvp` fallback panics.
    fn psd_majorizer_hvp(
        &self,
        target: ArrayView1<'_, f64>,
        rho: ArrayView1<'_, f64>,
        v: ArrayView1<'_, f64>,
    ) -> Array1<f64> {
        match self.psd_majorizer_diag(target, rho) {
            Some(diag) => {
                assert_eq!(diag.len(), v.len(), "psd_majorizer_hvp dimension mismatch");
                Array1::from_iter(diag.iter().zip(v.iter()).map(|(d, x)| d * x))
            }
            None => self.hvp(target, rho, v),
        }
    }

    /// Gradient of the penalty with respect to `rho`.
    fn grad_rho(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64>;

    /// Number of `rho` entries this penalty owns.
    fn rho_count(&self) -> usize;

    /// Human-readable name, used in diagnostics.
    fn name(&self) -> &str;

    /// Hook invoked with the outer iteration index so that scheduled
    /// penalties can update their weights. The default changes nothing.
    ///
    /// # Panics
    ///
    /// The default panics when `iter` is one million or more.
    fn apply_schedule(&mut self, iter: usize) {
        const IMPLAUSIBLE_ITER: usize = 1_000_000;
        assert!(
            iter < IMPLAUSIBLE_ITER,
            "apply_schedule received implausible outer iteration {iter}",
        );
    }
}
/// Moves `weight` to the schedule's value at iteration `iter`.
///
/// When a schedule is present, `weight` is overwritten with
/// `current_weight(iter)` and the schedule's `iter_count` is set to
/// `iter + 1`. Without a schedule nothing is modified.
pub(crate) fn advance_scalar_weight(
    weight: &mut f64,
    schedule: &mut Option<ScalarWeightSchedule>,
    iter: usize,
) {
    match schedule {
        Some(active) => {
            *weight = active.current_weight(iter);
            active.iter_count = iter + 1;
        }
        None => {}
    }
}
// Generates a builder-style `with_weight_schedule` method for a type that has
// a scalar weight field (named by `$field`) and a
// `weight_schedule: Option<ScalarWeightSchedule>` field.
macro_rules! impl_with_weight_schedule {
    ($field:ident) => {
        /// Attaches `schedule` and seeds the weight field with the schedule's
        /// value at its current `iter_count`.
        #[must_use]
        pub fn with_weight_schedule(self, schedule: ScalarWeightSchedule) -> Self {
            let mut configured = self;
            let initial_weight = schedule.current_weight(schedule.iter_count);
            configured.$field = initial_weight;
            configured.weight_schedule = Some(schedule);
            configured
        }
    };
}
// Generates an `apply_schedule` method that forwards the weight field named
// by `$field`, together with `self.weight_schedule`, to
// `advance_scalar_weight`.
macro_rules! impl_scalar_apply_schedule {
    ($field:ident) => {
        fn apply_schedule(&mut self, iter: usize) {
            // The two borrows touch different fields, so they can coexist.
            let weight = &mut self.$field;
            let schedule = &mut self.weight_schedule;
            advance_scalar_weight(weight, schedule, iter);
        }
    };
}
// Generates `grad_rho` for a penalty with a `learnable_weight: bool` flag and
// a `rho_index` field.
//
// With a fixed weight the gradient is empty. With a learnable weight it is a
// one-element array whose `rho_index` entry holds the penalty value itself.
//
// NOTE(review): storing the value as the gradient presumably relies on the
// weight scaling as `exp(rho)` — confirm against the implementors.
// NOTE(review): the array has length 1, so any `rho_index` other than 0
// panics on the indexing below — confirm implementors always use 0.
macro_rules! impl_learnable_weight_grad_rho {
    () => {
        fn grad_rho(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
            if self.learnable_weight {
                let penalty_value = self.value(target, rho);
                let mut gradient = Array1::<f64>::zeros(1);
                gradient[self.rho_index] = penalty_value;
                gradient
            } else {
                Array1::<f64>::zeros(0)
            }
        }
    };
}
// Generates `rho_count` for a penalty with a `learnable_weight` flag: one rho
// entry when the weight is learnable, none otherwise.
macro_rules! impl_learnable_weight_rho_count {
    () => {
        fn rho_count(&self) -> usize {
            self.learnable_weight.into()
        }
    };
}