use serde::{Deserialize, Serialize};
/// Residual (observation) error model: describes how the standard deviation of an
/// observation depends on the model prediction.
///
/// The formulas below describe the raw standard deviation for non-negative parameters.
/// `ResidualErrorModel::sigma` additionally clamps the result from below.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
pub enum ResidualErrorModel {
    /// Additive error: `sigma = a`, independent of the prediction.
    Constant { a: f64 },
    /// Proportional error: `sigma = b * |prediction|`.
    Proportional { b: f64 },
    /// Combined additive and proportional error: `sigma = sqrt(a^2 + b^2 * prediction^2)`.
    Combined { a: f64, b: f64 },
    /// Single-parameter model whose `sigma` is used directly as the standard deviation.
    ///
    /// NOTE(review): no log transform is applied anywhere in this file; confirm this is
    /// the intended semantics for an "exponential" error model.
    Exponential { sigma: f64 },
}
impl Default for ResidualErrorModel {
fn default() -> Self {
ResidualErrorModel::Constant { a: 1.0 }
}
}
impl ResidualErrorModel {
pub fn constant(a: f64) -> Self {
ResidualErrorModel::Constant { a }
}
pub fn proportional(b: f64) -> Self {
ResidualErrorModel::Proportional { b }
}
pub fn combined(a: f64, b: f64) -> Self {
ResidualErrorModel::Combined { a, b }
}
pub fn exponential(sigma: f64) -> Self {
ResidualErrorModel::Exponential { sigma }
}
pub fn sigma(&self, prediction: f64) -> f64 {
let raw_sigma = match self {
ResidualErrorModel::Constant { a } => *a,
ResidualErrorModel::Proportional { b } => b * prediction.abs(),
ResidualErrorModel::Combined { a, b } => {
(a.powi(2) + b.powi(2) * prediction.powi(2)).sqrt()
}
ResidualErrorModel::Exponential { sigma } => *sigma,
};
raw_sigma.max(f64::EPSILON.sqrt())
}
pub fn variance(&self, prediction: f64) -> f64 {
let sigma = self.sigma(prediction);
sigma.powi(2)
}
pub fn weighted_squared_residual(&self, observation: f64, prediction: f64) -> f64 {
let residual = observation - prediction;
let residual_sq = residual * residual;
match self {
ResidualErrorModel::Constant { .. } => {
residual_sq
}
ResidualErrorModel::Proportional { .. } => {
let pred_sq = prediction.powi(2).max(f64::EPSILON);
residual_sq / pred_sq
}
ResidualErrorModel::Combined { a, b } => {
let variance = (a.powi(2) + b.powi(2) * prediction.powi(2)).max(f64::EPSILON);
residual_sq / variance
}
ResidualErrorModel::Exponential { .. } => {
residual_sq
}
}
}
pub fn log_likelihood(&self, observation: f64, prediction: f64) -> f64 {
let sigma = self.sigma(prediction);
let residual = observation - prediction;
let normalized_residual = residual / sigma;
-0.5 * (std::f64::consts::TAU.ln() + 2.0 * sigma.ln() + normalized_residual.powi(2))
}
pub fn with_updated_sigma(self, new_sigma: f64) -> Self {
match self {
ResidualErrorModel::Constant { .. } => ResidualErrorModel::Constant { a: new_sigma },
ResidualErrorModel::Proportional { .. } => {
ResidualErrorModel::Proportional { b: new_sigma }
}
ResidualErrorModel::Combined { a: _, b } => {
ResidualErrorModel::Combined { a: new_sigma, b }
}
ResidualErrorModel::Exponential { .. } => {
ResidualErrorModel::Exponential { sigma: new_sigma }
}
}
}
pub fn primary_parameter(&self) -> f64 {
match self {
ResidualErrorModel::Constant { a } => *a,
ResidualErrorModel::Proportional { b } => *b,
ResidualErrorModel::Combined { a, .. } => *a,
ResidualErrorModel::Exponential { sigma } => *sigma,
}
}
pub fn is_proportional(&self) -> bool {
matches!(self, ResidualErrorModel::Proportional { .. })
}
pub fn is_constant(&self) -> bool {
matches!(self, ResidualErrorModel::Constant { .. })
}
pub fn is_combined(&self) -> bool {
matches!(self, ResidualErrorModel::Combined { .. })
}
pub fn is_exponential(&self) -> bool {
matches!(self, ResidualErrorModel::Exponential { .. })
}
}
/// Collection of [`ResidualErrorModel`]s, one slot per output equation.
///
/// The model for output equation `outeq` lives at index `outeq` of the inner vector.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ResidualErrorModels {
    /// Error models indexed by output-equation number.
    models: Vec<ResidualErrorModel>,
}
impl ResidualErrorModels {
    /// Creates an empty collection with no models registered.
    #[must_use]
    pub fn new() -> Self {
        Self { models: Vec::new() }
    }

    /// Registers `model` for output equation `outeq` and returns the updated collection.
    ///
    /// This is a builder-style method, so the result must be used. Any lower indices that
    /// were never set are filled with `ResidualErrorModel::default()`, which means `get`
    /// returns `Some` for them.
    #[must_use]
    pub fn add(mut self, outeq: usize, model: ResidualErrorModel) -> Self {
        if outeq >= self.models.len() {
            self.models.resize(outeq + 1, ResidualErrorModel::default());
        }
        self.models[outeq] = model;
        self
    }

    /// Model for output equation `outeq`, or `None` if `outeq` is out of range.
    pub fn get(&self, outeq: usize) -> Option<&ResidualErrorModel> {
        self.models.get(outeq)
    }

    /// Mutable access to the model for output equation `outeq`, if present.
    pub fn get_mut(&mut self, outeq: usize) -> Option<&mut ResidualErrorModel> {
        self.models.get_mut(outeq)
    }

    /// Standard deviation for `outeq` at `prediction`, or `None` if no model is stored.
    pub fn sigma(&self, outeq: usize, prediction: f64) -> Option<f64> {
        self.get(outeq).map(|model| model.sigma(prediction))
    }

    /// Number of slots, including any default-filled gaps created by `add`.
    pub fn len(&self) -> usize {
        self.models.len()
    }

    /// `true` when no slots exist.
    pub fn is_empty(&self) -> bool {
        self.models.is_empty()
    }

    /// Iterates over `(outeq, model)` pairs in index order.
    pub fn iter(&self) -> impl Iterator<Item = (usize, &ResidualErrorModel)> {
        self.models.iter().enumerate()
    }

    /// Log-likelihood of one observation for `outeq`, or `None` if no model is stored.
    pub fn log_likelihood(&self, outeq: usize, observation: f64, prediction: f64) -> Option<f64> {
        self.get(outeq)
            .map(|model| model.log_likelihood(observation, prediction))
    }

    /// Sums the log-likelihoods of `(outeq, observation, prediction)` triples.
    ///
    /// Returns `0.0` for an empty input. Returns `f64::NEG_INFINITY` as soon as a triple
    /// references an `outeq` with no model; the remaining triples are not evaluated.
    pub fn total_log_likelihood<I>(&self, obs_pred_pairs: I) -> f64
    where
        I: IntoIterator<Item = (usize, f64, f64)>,
    {
        obs_pred_pairs
            .into_iter()
            .try_fold(0.0, |total, (outeq, obs, pred)| {
                self.log_likelihood(outeq, obs, pred).map(|ll| total + ll)
            })
            .unwrap_or(f64::NEG_INFINITY)
    }

    /// Replaces the primary parameter of every stored model with `new_sigma`.
    ///
    /// Each model is updated via `ResidualErrorModel::with_updated_sigma`.
    pub fn update_sigma(&mut self, new_sigma: f64) {
        for model in &mut self.models {
            *model = model.with_updated_sigma(new_sigma);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance shared by every floating-point comparison in this module.
    const TOL: f64 = 1e-10;

    /// `true` when `actual` lies strictly within [`TOL`] of `expected`.
    fn close(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < TOL
    }

    #[test]
    fn test_constant_error() {
        let model = ResidualErrorModel::constant(0.5);
        for prediction in [0.0, 100.0, -50.0] {
            assert!(close(model.sigma(prediction), 0.5));
        }
    }

    #[test]
    fn test_proportional_error() {
        let model = ResidualErrorModel::proportional(0.1);
        for (prediction, expected) in [(100.0, 10.0), (50.0, 5.0), (-100.0, 10.0)] {
            assert!(close(model.sigma(prediction), expected));
        }
    }

    #[test]
    fn test_combined_error() {
        let model = ResidualErrorModel::combined(0.5, 0.1);
        assert!(close(model.sigma(0.0), 0.5));
        assert!(close(model.sigma(100.0), 100.25_f64.sqrt()));
    }

    #[test]
    fn test_weighted_residual() {
        let constant = ResidualErrorModel::constant(1.0);
        assert!(close(constant.weighted_squared_residual(5.0, 3.0), 4.0));

        let proportional = ResidualErrorModel::proportional(0.1);
        assert!(close(proportional.weighted_squared_residual(12.0, 10.0), 0.04));
    }

    #[test]
    fn test_sigma_cutoff() {
        let floored = ResidualErrorModel::proportional(0.1).sigma(0.0);
        assert!(floored > 0.0);
        assert!(floored >= f64::EPSILON.sqrt());
    }

    #[test]
    fn test_log_likelihood() {
        let ll = ResidualErrorModel::constant(1.0).log_likelihood(1.0, 0.0);
        // Standard normal log-density at one standard deviation from the mean.
        let expected = -0.5 * (std::f64::consts::TAU.ln() + 1.0);
        assert!(close(ll, expected));
    }

    #[test]
    fn test_residual_error_models_collection() {
        let models = ResidualErrorModels::new()
            .add(0, ResidualErrorModel::constant(0.5))
            .add(1, ResidualErrorModel::proportional(0.1));

        assert_eq!(models.len(), 2);
        assert!(models.get(0).unwrap().is_constant());
        assert!(models.get(1).unwrap().is_proportional());
        assert!(close(models.sigma(0, 100.0).unwrap(), 0.5));
        assert!(close(models.sigma(1, 100.0).unwrap(), 10.0));
    }
}