use std::fmt::{Debug, Display};
use dyn_clone::DynClone;
use fastrand::Rng;
use serde::{Deserialize, Serialize};
use crate::{core::utils::SampleFloat, Float};
/// An optional lower and/or upper limit on a scalar [`Float`] value.
///
/// Each variant records which sides are limited. A missing side is reported as
/// `-∞` / `+∞` by [`Bound::lower`] / [`Bound::upper`], and present limits are
/// treated as inclusive by [`Bound::contains`].
///
/// The variants are public and the type is deserializable, so nothing in this
/// definition enforces `lower <= upper` or rules out non-finite payloads.
#[derive(Default, Copy, Clone, Debug, Serialize, Deserialize, PartialEq)]
pub enum Bound {
/// No limit on either side; this is the `Default` value.
#[default]
NoBound,
/// Limited from below only; the payload is the lower limit.
LowerBound(Float),
/// Limited from above only; the payload is the upper limit.
UpperBound(Float),
/// Limited on both sides; the payload is `(lower, upper)`, in that order.
LowerAndUpperBound(Float, Float),
}
impl Display for Bound {
    /// Writes the bound as `(lower, upper)`.
    ///
    /// A side without a limit is printed as the corresponding infinity, because the values
    /// come from [`Bound::as_floats`].
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (lo, hi) = self.as_floats();
        write!(f, "({lo}, {hi})")
    }
}
impl Bound {
    /// Returns the limits as `(lower, upper)`, with `None` for a side that is not limited.
    pub const fn as_options(&self) -> (Option<Float>, Option<Float>) {
        match *self {
            Self::LowerAndUpperBound(lo, hi) => (Some(lo), Some(hi)),
            Self::LowerBound(lo) => (Some(lo), None),
            Self::UpperBound(hi) => (None, Some(hi)),
            Self::NoBound => (None, None),
        }
    }

    /// Returns the limits as `(lower, upper)`, with `-∞` / `+∞` standing in for a missing side.
    pub const fn as_floats(&self) -> (Float, Float) {
        let lo = self.lower();
        let hi = self.upper();
        (lo, hi)
    }

    /// Returns `true` when the bound carries an upper limit.
    pub const fn has_upper(&self) -> bool {
        matches!(self, Self::UpperBound(_) | Self::LowerAndUpperBound(..))
    }

    /// Returns `true` when the bound carries a lower limit.
    pub const fn has_lower(&self) -> bool {
        matches!(self, Self::LowerBound(_) | Self::LowerAndUpperBound(..))
    }

    /// Returns the lower limit, or `-∞` when there is none.
    pub const fn lower(&self) -> Float {
        match *self {
            Self::LowerBound(lo) | Self::LowerAndUpperBound(lo, _) => lo,
            Self::NoBound | Self::UpperBound(_) => Float::NEG_INFINITY,
        }
    }

    /// Returns the upper limit, or `+∞` when there is none.
    pub const fn upper(&self) -> Float {
        match *self {
            Self::UpperBound(hi) | Self::LowerAndUpperBound(_, hi) => hi,
            Self::NoBound | Self::LowerBound(_) => Float::INFINITY,
        }
    }

    /// Draws a sample associated with this bound using `rng`.
    ///
    /// * Two-sided: the result of `rng.range(lower, upper)`.
    /// * One-sided: the limit shifted into the allowed side by the absolute value of
    ///   `rng.normal(0.0, 1.0)`.
    /// * Unbounded: `rng.normal(0.0, 1.0)` itself.
    ///
    /// Exactly one value is requested from `rng` per call.
    pub fn random(&self, rng: &mut Rng) -> Float {
        match *self {
            Self::LowerAndUpperBound(lo, hi) => rng.range(lo, hi),
            Self::LowerBound(lo) => {
                let offset: Float = rng.normal(0.0, 1.0);
                lo + offset.abs()
            }
            Self::UpperBound(hi) => {
                let offset: Float = rng.normal(0.0, 1.0);
                hi - offset.abs()
            }
            Self::NoBound => rng.normal(0.0, 1.0),
        }
    }

    /// Returns `true` when `value` lies within the bound; present limits are inclusive.
    ///
    /// `NaN` is never contained.
    pub fn contains(&self, value: Float) -> bool {
        // A missing side is ±∞, which every non-NaN value satisfies, and any comparison
        // involving NaN is false, so no separate NaN or "has limit" check is needed.
        let (lo, hi) = self.as_floats();
        lo <= value && value <= hi
    }

    /// Returns the signed distance by which `value` violates the bound.
    ///
    /// The result is `0.0` inside the bound, positive (`value - upper`) above the upper limit
    /// and negative (`value - lower`) below the lower limit.
    pub fn get_excess(&self, value: Float) -> Float {
        let over = |hi: Float| Float::max(value - hi, 0.0);
        let under = |lo: Float| Float::min(value - lo, 0.0);
        match *self {
            Self::NoBound => 0.0,
            Self::UpperBound(hi) => over(hi),
            Self::LowerBound(lo) => under(lo),
            Self::LowerAndUpperBound(lo, hi) => over(hi) + under(lo),
        }
    }

    /// Returns `true` when `value` is strictly closer than `tol` to a limit that is present.
    pub fn at_bound(&self, value: Float, tol: Float) -> bool {
        let near = |edge: Float| (value - edge).abs() < tol;
        let (lo, hi) = self.as_options();
        matches!(hi, Some(edge) if near(edge)) || matches!(lo, Some(edge) if near(edge))
    }

    /// Restricts `value` to the bound, returning the nearest limit when it lies outside.
    ///
    /// # Panics
    ///
    /// For a two-sided bound this calls `clamp`, which panics when `lower > upper` or when
    /// either limit is `NaN`.
    pub fn clip_value(&self, value: Float) -> Float {
        match *self {
            Self::NoBound => value,
            Self::UpperBound(hi) => value.min(hi),
            Self::LowerBound(lo) => value.max(lo),
            Self::LowerAndUpperBound(lo, hi) => value.clamp(lo, hi),
        }
    }
}
impl From<(Float, Float)> for Bound {
    /// Builds a [`Bound`] from a pair of endpoints.
    ///
    /// Comparable endpoints are sorted so the smaller one becomes the lower limit. After
    /// sorting, a non-finite endpoint (`±∞` or `NaN`) means "no limit on that side".
    ///
    /// A `NaN` endpoint cannot be ordered, so the pair is left in the order given: the other
    /// endpoint keeps the side the caller put it on, e.g. `(1.0, NaN)` yields
    /// `LowerBound(1.0)` and `(NaN, 1.0)` yields `UpperBound(1.0)`.
    fn from(value: (Float, Float)) -> Self {
        let (a, b) = value;
        // `a < b` is false whenever either side is NaN; swapping in that case would move the
        // remaining finite endpoint to the opposite side, so only swap ordered pairs.
        let keep_order = a < b || a.is_nan() || b.is_nan();
        let (l, u) = if keep_order { (a, b) } else { (b, a) };
        match (l.is_finite(), u.is_finite()) {
            (true, true) => Self::LowerAndUpperBound(l, u),
            (true, false) => Self::LowerBound(l),
            (false, true) => Self::UpperBound(u),
            (false, false) => Self::NoBound,
        }
    }
}
impl From<(Option<Float>, Option<Float>)> for Bound {
    /// Builds a [`Bound`] from optional endpoints.
    ///
    /// A missing first endpoint is replaced by `-∞` and a missing second endpoint by `+∞`;
    /// the resulting pair is then converted through the `(Float, Float)` conversion.
    fn from(value: (Option<Float>, Option<Float>)) -> Self {
        let (first, second) = value;
        let endpoints = (
            first.unwrap_or(Float::NEG_INFINITY),
            second.unwrap_or(Float::INFINITY),
        );
        Self::from(endpoints)
    }
}
/// Trait-object interface whose methods each take a [`Bound`] plus one scalar and return a
/// scalar.
///
/// Implementors must be clonable through `dyn_clone` and be `Debug + Send + Sync`. The
/// `#[typetag::serde]` attribute makes trait objects of this trait (de)serializable with serde.
///
/// NOTE(review): judging only by the method names, an implementor presumably defines a change
/// of variables between an "external" value `x` (subject to `bound`) and an "internal" value
/// `z`, where the `d_` / `dd_` prefixes denote first / second derivatives of the corresponding
/// mapping. No implementation is visible here, so confirm against the implementors.
#[typetag::serde]
pub trait BoundLike: DynClone + Debug + Send + Sync {
/// Presumably maps the external value `x` to its internal counterpart under `bound`.
fn to_internal_impl(&self, bound: Bound, x: Float) -> Float;
/// Presumably the first derivative of the external-to-internal mapping at `x`.
fn d_to_internal_impl(&self, bound: Bound, x: Float) -> Float;
/// Presumably the second derivative of the external-to-internal mapping at `x`.
fn dd_to_internal_impl(&self, bound: Bound, x: Float) -> Float;
/// Presumably maps the internal value `z` back to its external counterpart under `bound`.
fn to_external_impl(&self, bound: Bound, z: Float) -> Float;
/// Presumably the first derivative of the internal-to-external mapping at `z`.
fn d_to_external_impl(&self, bound: Bound, z: Float) -> Float;
/// Presumably the second derivative of the internal-to-external mapping at `z`.
fn dd_to_external_impl(&self, bound: Bound, z: Float) -> Float;
}
// Implements `Clone` for `Box<dyn BoundLike>` so boxed trait objects can be cloned.
dyn_clone::clone_trait_object!(BoundLike);
#[cfg(test)]
mod tests {
    use super::*;
    use fastrand::Rng;

    /// Asserts that `bound` reports exactly the given lower and upper limits.
    fn assert_limits(bound: Bound, lower: Float, upper: Float) {
        assert_eq!(bound.lower(), lower);
        assert_eq!(bound.upper(), upper);
    }

    #[test]
    fn test_bounds_creation() {
        const NEG_INF: Float = Float::NEG_INFINITY;
        const INF: Float = Float::INFINITY;

        // Construction from optional endpoints, including a reversed pair.
        assert_limits(Bound::from((None, None)), NEG_INF, INF);
        assert_limits(Bound::from((None, Some(1.2))), NEG_INF, 1.2);
        assert_limits(Bound::from((Some(-3.4), None)), -3.4, INF);
        assert_limits(Bound::from((Some(-3.4), Some(1.2))), -3.4, 1.2);
        assert_limits(Bound::from((Some(1.2), Some(-3.4))), -3.4, 1.2);

        // Construction from plain floats, where infinities mark a missing side.
        assert_limits(Bound::from((NEG_INF, INF)), NEG_INF, INF);
        assert_limits(Bound::from((INF, NEG_INF)), NEG_INF, INF);
        assert_limits(Bound::from((NEG_INF, 1.2)), NEG_INF, 1.2);
        assert_limits(Bound::from((-3.4, INF)), -3.4, INF);
        assert_limits(Bound::from((-3.4, 1.2)), -3.4, 1.2);
        assert_limits(Bound::from((1.2, -3.4)), -3.4, 1.2);
    }

    #[test]
    fn test_bound_contains_is_inclusive_at_finite_endpoints() {
        let bounded = Bound::LowerAndUpperBound(-1.0, 1.0);
        for endpoint in [-1.0, 1.0] {
            assert!(bounded.contains(endpoint));
        }
        assert!(!bounded.contains(1.1));

        assert!(Bound::LowerBound(2.0).contains(2.0));
        assert!(Bound::UpperBound(3.0).contains(3.0));
    }

    #[test]
    fn test_bound_random_with_infinite_endpoints_stays_finite() {
        let mut rng = Rng::with_seed(0);
        let open_ended = [
            Bound::NoBound,
            Bound::LowerBound(1.5),
            Bound::UpperBound(-2.5),
        ];
        for bound in open_ended.iter() {
            let sample = bound.random(&mut rng);
            assert!(sample.is_finite());
            assert!(bound.contains(sample));
        }
    }
}