#![allow(clippy::cast_possible_truncation)]
#![allow(clippy::cast_sign_loss)]
use crate::{
basis::{Basis, DifferentialBasis},
display::PolynomialDisplay,
score::{ModelScoreProvider, ScoringMethod},
value::{FloatClampedCast, Value},
CurveFit,
};
/// Preset strength of a soft shape penalty.
///
/// Each variant maps to a regularization factor (lambda); `Custom` carries a
/// caller-supplied value.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum PenaltyWeight {
    /// Penalty disabled (lambda = 0).
    None,
    /// lambda = 1e-6.
    Small,
    /// lambda = 1e-4.
    Medium,
    /// lambda = 1e-2.
    Large,
    /// Effectively a hard constraint (lambda = 1e6).
    Hard,
    /// Caller-supplied lambda.
    Custom(f64),
}
impl PenaltyWeight {
    /// Maps this weight to its concrete regularization factor in `T`.
    #[inline(always)]
    fn to_lambda<T: Value>(self) -> T {
        // Resolve the preset to its f64 magnitude, then cast into `T`
        // via the crate's saturating/clamped conversion.
        let raw: f64 = match self {
            Self::None => 0.0,
            Self::Small => 1e-6,
            Self::Medium => 1e-4,
            Self::Large => 1e-2,
            Self::Hard => 1e6,
            Self::Custom(custom) => custom,
        };
        raw.clamped_cast()
    }
}
impl<T: Value> From<T> for PenaltyWeight {
    /// Wraps any numeric value as a custom penalty weight.
    fn from(value: T) -> Self {
        let weight: f64 = value.clamped_cast();
        Self::Custom(weight)
    }
}
/// How many data points to sample when evaluating shape penalties.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum SamplingStrategy {
    /// A fraction of the data set, in `[0.0, 1.0]`.
    Percentage(f64),
    /// A fixed number of points (clamped into the data length at resolution
    /// time).
    Count(usize),
    /// Every data point.
    Total,
}
impl SamplingStrategy {
    /// Creates a percentage-based strategy; `percent` is clamped to `[0, 1]`.
    #[must_use]
    pub fn new_percentage(percent: f64) -> Self {
        SamplingStrategy::Percentage(percent.clamp(0.0, 1.0))
    }
    /// Creates a fixed-count strategy.
    #[must_use]
    pub fn new_count(count: usize) -> Self {
        SamplingStrategy::Count(count)
    }
    /// Creates a strategy that uses every data point.
    #[must_use]
    pub fn new_total() -> Self {
        SamplingStrategy::Total
    }
    /// Resolves the strategy to a concrete sample count for `data_len`
    /// points. Always returns a value in `0..=data_len`.
    #[must_use]
    pub fn count(&self, data_len: usize) -> usize {
        match self {
            SamplingStrategy::Percentage(percent) => {
                (percent.clamp(0.0, 1.0) * (data_len as f64)).round() as usize
            }
            // `max(1)` then `min(data_len)` keeps the count in range without
            // panicking on empty data: the previous `clamp(1, data_len)`
            // panics when `data_len == 0` because clamp's lower bound would
            // exceed its upper bound.
            SamplingStrategy::Count(count) => (*count).max(1).min(data_len),
            SamplingStrategy::Total => data_len,
        }
    }
}
/// Target direction for the monotonicity penalty.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MonotonicityDirection {
    /// Infer the direction from the fitted curve's first derivative
    /// (majority vote over the sign at sampled points).
    Infer,
    /// Penalize negative slope.
    Increasing,
    /// Penalize positive slope.
    Decreasing,
}
/// Score provider that augments a base goodness-of-fit score with soft
/// penalties for curvature and non-monotonicity, evaluated at sampled points
/// across the fitted x-range.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ShapeConstraint<T: Value> {
    // Underlying goodness-of-fit score (e.g. RMSE, MAE).
    base_score_provider: ScoringMethod,
    // Regularization weight applied to the squared second derivative.
    lambda_curvature: T,
    // Regularization weight applied to monotonicity violations.
    lambda_monotonic: T,
    // Direction the monotonicity penalty enforces.
    monotonic_direction: MonotonicityDirection,
    // How many points along the fitted range are sampled for the penalties.
    samples: SamplingStrategy,
}
impl<T: Value> ShapeConstraint<T> {
    /// Builds a shape constraint around an arbitrary base score provider,
    /// with both penalties disabled and direction left to be inferred.
    #[must_use]
    pub fn new_with_provider(
        base_score_provider: impl Into<ScoringMethod>,
        sampling_strategy: SamplingStrategy,
    ) -> Self {
        let provider: ScoringMethod = base_score_provider.into();
        Self {
            base_score_provider: provider,
            lambda_curvature: T::zero(),
            lambda_monotonic: T::zero(),
            monotonic_direction: MonotonicityDirection::Infer,
            samples: sampling_strategy,
        }
    }
    /// Convenience constructor using root-mean-squared error as the base.
    #[must_use]
    pub fn new_rmse(sampling_strategy: SamplingStrategy) -> Self {
        Self::new_with_provider(ScoringMethod::RootMeanSquaredError, sampling_strategy)
    }
    /// Convenience constructor using mean absolute error as the base.
    #[must_use]
    pub fn new_mae(sampling_strategy: SamplingStrategy) -> Self {
        Self::new_with_provider(ScoringMethod::MeanAbsoluteError, sampling_strategy)
    }
    /// Builder: enables the curvature penalty with the given weight.
    #[must_use]
    pub fn with_curvature_penalty(mut self, curvature_penalty: impl Into<PenaltyWeight>) -> Self {
        let weight: PenaltyWeight = curvature_penalty.into();
        self.lambda_curvature = weight.to_lambda();
        self
    }
    /// Builder: enables the monotonicity penalty with the given weight and
    /// target direction.
    #[must_use]
    pub fn with_monotonic_penalty(
        mut self,
        monotonic_penalty: impl Into<PenaltyWeight>,
        direction: MonotonicityDirection,
    ) -> Self {
        let weight: PenaltyWeight = monotonic_penalty.into();
        self.lambda_monotonic = weight.to_lambda();
        self.monotonic_direction = direction;
        self
    }
}
impl<B, T> ModelScoreProvider<B, T> for ShapeConstraint<T>
where
    T: Value,
    B: Basis<T> + PolynomialDisplay<T> + DifferentialBasis<T>,
    B::B2: DifferentialBasis<T>,
{
    fn minimum_significant_distance(&self) -> Option<usize> {
        // Shape penalties do not impose a minimum point-spacing requirement.
        None
    }
    /// Scores `model` as the base score plus two soft penalties evaluated on
    /// a uniform grid over the fitted x-range:
    /// - curvature: sum of the squared, normalized second derivative;
    /// - monotonicity: sum of squared first-derivative excursions against
    ///   the configured (or inferred) direction.
    ///
    /// Each sum is multiplied by `lambda * stepsize`, approximating an
    /// integral over the x-range. Falls back to the base score alone when
    /// derivatives are unavailable or fewer than two samples are requested.
    fn score(
        &self,
        model: &CurveFit<B, T>,
        y: impl Iterator<Item = T>,
        y_fit: impl Iterator<Item = T>,
        k: T,
    ) -> T {
        let base_score = self.base_score_provider.score(model, y, y_fit, k);
        let range = model.x_range();
        let min = *range.start();
        let max = *range.end();
        let x_range = max - min;
        let y_range = model.y_range();
        let y_range = *y_range.end() - *y_range.start();
        // Guard against flat data: dividing by a zero y-range below would
        // turn the normalized derivatives into inf/NaN and poison the score.
        let y_scale = if y_range.is_zero() {
            T::from_positive_int(1)
        } else {
            y_range
        };
        // Both penalties need the first derivative; curvature additionally
        // uses the second when it exists.
        let Ok(d1) = model.as_polynomial().derivative() else {
            return base_score;
        };
        let d2 = d1.derivative().ok();
        let samples = self.samples.count(model.data().len());
        if samples < 2 {
            return base_score;
        }
        // Tolerance below which a derivative is treated as numerically flat,
        // scaled to the x-range.
        let mono_epsilon = Value::abs(max - min) * T::from_f64(1e-8).unwrap_or(T::epsilon());
        let stepsize = (max - min) / T::from_positive_int(samples - 1);
        // i-th point of the uniform sampling grid over [min, max].
        let sample_x = |i: usize| min + T::from_positive_int(i) * stepsize;
        // Resolve `Infer` by majority vote over the sign of d1 on the grid.
        let mut monotonic_direction = self.monotonic_direction;
        if !self.lambda_monotonic.is_zero() && monotonic_direction == MonotonicityDirection::Infer {
            let (mut pos, mut neg) = (0, 0);
            for i in 0..samples {
                let x = sample_x(i);
                if x > max {
                    // Float rounding can push the last grid point past max.
                    break;
                }
                let v1 = d1.y(x);
                if v1 > mono_epsilon {
                    pos += 1;
                } else if v1 < -mono_epsilon {
                    neg += 1;
                }
            }
            monotonic_direction = if pos >= neg {
                MonotonicityDirection::Increasing
            } else {
                MonotonicityDirection::Decreasing
            };
        }
        let mut curvature = T::zero();
        let mut monotonic = T::zero();
        for i in 0..samples {
            let x = sample_x(i);
            if x > max {
                break;
            }
            // Normalize the derivatives to the data's aspect ratio so the
            // penalties are invariant to the scale of x and y.
            let v1 = d1.y(x) * (x_range / y_scale);
            let v2 = d2.as_ref().map_or(T::zero(), |d| d.y(x)) * (x_range * x_range / y_scale);
            curvature += v2 * v2;
            // A violation is slope in the direction opposite to the target.
            let violation = match monotonic_direction {
                MonotonicityDirection::Increasing => Value::min(v1, T::zero()),
                MonotonicityDirection::Decreasing => Value::max(v1, T::zero()),
                MonotonicityDirection::Infer => T::zero(),
            };
            if Value::abs(violation) > mono_epsilon {
                monotonic += violation * violation;
            }
        }
        // `lambda * stepsize * sum` is a Riemann-sum approximation of the
        // integral of each squared penalty over the x-range.
        let curvature_penalty = self.lambda_curvature * stepsize;
        let monotonic_penalty = self.lambda_monotonic * stepsize;
        base_score + curvature_penalty * curvature + monotonic_penalty * monotonic
    }
}