use nalgebra::{
DefaultAllocator, Dim, OVector, RealField, Scalar, U1, VectorView, allocator::Allocator,
};
use serde::{Deserialize, Serialize};
use std::fmt::Debug;
/// Admissible values for a `D`-dimensional vector whose components have type `T`.
///
/// Bounds are inclusive, and a `None` bound leaves that side of a coordinate unconstrained.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
#[serde(bound(
    serialize = "D: Serialize, OVector<(Option<T>, Option<T>), D>: Serialize",
    deserialize = "D: Deserialize<'de>, OVector<(Option<T>, Option<T>), D>: Deserialize<'de>"
))]
pub enum Domain<T, D>
where
    T: Scalar,
    D: Dim,
    DefaultAllocator: Allocator<D>,
{
    /// Unbounded domain: only the dimension is stored, and every coordinate may take any value.
    UDomain(D),
    /// Bounded domain: one `(min, max)` pair per coordinate, where either side may be `None`.
    MDomain(OVector<(Option<T>, Option<T>), D>),
}
impl<T> Domain<T, U1>
where
    T: RealField,
{
    /// Returns the single `(min, max)` bound pair of a one-dimensional domain.
    ///
    /// Returns `None` for [`Domain::UDomain`], which stores no bounds.
    pub fn inner(&self) -> Option<(Option<T>, Option<T>)> {
        let Domain::MDomain(bounds) = self else {
            return None;
        };
        // A one-dimensional vector always has exactly one entry, so index 0 is in bounds.
        Some(bounds[0].clone())
    }
}
impl<T, D> Domain<T, D>
where
T: RealField,
D: Dim,
DefaultAllocator: Allocator<D>,
{
pub fn clamp<RStride: Dim, CStride: Dim>(
&self,
sample: &VectorView<T, D, RStride, CStride>,
) -> OVector<T, D> {
match self {
Domain::UDomain(_) => sample.clone_owned(),
Domain::MDomain(sdoms) => OVector::from_iterator_generic(
sample.shape_generic().0,
U1,
sdoms.iter().enumerate().map(|(i, (opt_min, opt_max))| {
let value = &sample[i];
if let Some(min) = opt_min
&& value < min
{
return min.clone();
}
if let Some(max) = opt_max
&& value > max
{
return max.clone();
}
value.clone()
}),
),
}
}
pub fn contains<RStride: Dim, CStride: Dim>(
&self,
sample: &VectorView<T, D, RStride, CStride>,
) -> bool {
match self {
Domain::UDomain(_) => true,
Domain::MDomain(sdoms) => sdoms.iter().zip(sample).all(|(sdom, value)| match sdom {
(Some(min), Some(max)) => (value >= min) & (value <= max),
(Some(min), None) => value >= min,
(None, Some(max)) => value <= max,
(None, None) => true,
}),
}
}
pub fn maximum_values(&self) -> OVector<Option<T>, D> {
match self {
Domain::UDomain(dim) => OVector::from_element_generic(*dim, U1, None),
Domain::MDomain(sdoms) => OVector::from_iterator_generic(
sdoms.shape_generic().0,
U1,
sdoms.iter().map(|sdom| sdom.1.clone()),
),
}
}
pub fn minimum_values(&self) -> OVector<Option<T>, D> {
match self {
Domain::UDomain(dim) => OVector::from_element_generic(*dim, U1, None),
Domain::MDomain(sdoms) => OVector::from_iterator_generic(
sdoms.shape_generic().0,
U1,
sdoms.iter().map(|sdom| sdom.0.clone()),
),
}
}
pub fn new_mdomain(domains: OVector<(Option<T>, Option<T>), D>) -> Self {
Domain::MDomain(domains)
}
pub fn new_udomain(dim: D) -> Self {
Domain::UDomain(dim)
}
pub fn shape_generic(&self) -> D {
match self {
Domain::UDomain(udom) => *udom,
Domain::MDomain(sdoms) => sdoms.shape_generic().0,
}
}
pub fn size(&self) -> OVector<Option<T>, D> {
match self {
Domain::UDomain(udom) => OVector::from_element_generic(*udom, U1, None),
Domain::MDomain(sdoms) => OVector::from_iterator_generic(
sdoms.shape_generic().0,
U1,
sdoms.iter().map(|sdom| match sdom {
(Some(min), Some(max)) => Some(max.clone() - min.clone()),
_ => None,
}),
),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use nalgebra::U2;
    #[test]
    fn test_boundaries_exclusive_outside() {
        let domain = Domain::new_mdomain(OVector::from([(Some(0.0), Some(1.0))]));
        let below_lower: f64 = -0.001;
        assert!(!domain.contains::<U1, U1>(&OVector::from([below_lower]).as_view()));
        let above_upper: f64 = 1.001;
        assert!(!domain.contains::<U1, U1>(&OVector::from([above_upper]).as_view()));
        let far_negative: f64 = -1e6;
        assert!(!domain.contains::<U1, U1>(&OVector::from([far_negative]).as_view()));
        let far_positive: f64 = 1e6;
        assert!(!domain.contains::<U1, U1>(&OVector::from([far_positive]).as_view()));
    }
    #[test]
    fn test_boundaries_inclusive() {
        let domain = Domain::new_mdomain(OVector::from([(Some(0.0), Some(1.0))]));
        let lower_bound: f64 = 0.0;
        assert!(domain.contains::<U1, U1>(&OVector::from([lower_bound]).as_view()));
        let upper_bound: f64 = 1.0;
        assert!(domain.contains::<U1, U1>(&OVector::from([upper_bound]).as_view()));
        let interior: f64 = 0.5;
        assert!(domain.contains::<U1, U1>(&OVector::from([interior]).as_view()));
    }
    #[test]
    fn test_clamp_above_maximum() {
        let domain = Domain::new_mdomain(OVector::from([(Some(0.0), Some(1.0))]));
        let sample_above = OVector::from([1.5]);
        let clamped = domain.clamp::<U1, U1>(&sample_above.as_view());
        assert_eq!(clamped[0], 1.0);
    }
    #[test]
    fn test_clamp_below_minimum() {
        let domain = Domain::new_mdomain(OVector::from([(Some(0.0), Some(1.0))]));
        let sample_below = OVector::from([-0.5]);
        let clamped = domain.clamp::<U1, U1>(&sample_below.as_view());
        assert_eq!(clamped[0], 0.0);
    }
    #[test]
    fn test_clamp_half_bounded_lower() {
        let domain = Domain::new_mdomain(OVector::from([(Some(0.0), None)]));
        let below = OVector::from([-1.0]);
        let clamped = domain.clamp::<U1, U1>(&below.as_view());
        assert_eq!(clamped[0], 0.0);
        let above = OVector::from([1e6]);
        let clamped = domain.clamp::<U1, U1>(&above.as_view());
        assert_eq!(clamped[0], 1e6);
    }
    #[test]
    fn test_clamp_half_bounded_upper() {
        let domain = Domain::new_mdomain(OVector::from([(None, Some(1.0))]));
        let above = OVector::from([2.0]);
        let clamped = domain.clamp::<U1, U1>(&above.as_view());
        assert_eq!(clamped[0], 1.0);
        let below = OVector::from([-1e6]);
        let clamped = domain.clamp::<U1, U1>(&below.as_view());
        assert_eq!(clamped[0], -1e6);
    }
    #[test]
    fn test_clamp_unbounded() {
        let domain = Domain::new_udomain(U1);
        let sample = OVector::from([42.0]);
        let clamped = domain.clamp::<U1, U1>(&sample.as_view());
        assert_eq!(clamped[0], 42.0);
    }
    #[test]
    fn test_clamp_with_explicit_bounds() {
        let domain = Domain::new_mdomain(OVector::from([(Some(0.0), Some(1.0))]));
        let below = OVector::from([-0.5]);
        let clamped = domain.clamp::<U1, U1>(&below.as_view());
        assert_eq!(clamped[0], 0.0);
        let above = OVector::from([1.5]);
        let clamped = domain.clamp::<U1, U1>(&above.as_view());
        assert_eq!(clamped[0], 1.0);
        let inside = OVector::from([0.5]);
        let clamped = domain.clamp::<U1, U1>(&inside.as_view());
        assert_eq!(clamped[0], 0.5);
    }
    #[test]
    fn test_contains_unbounded() {
        let domain: Domain<f64, U1> = Domain::UDomain(U1);
        let sample = OVector::from([-1e6]);
        assert!(domain.contains::<U1, U1>(&sample.as_view()));
    }
    #[test]
    fn test_half_bounded_lower() {
        let domain = Domain::new_mdomain(OVector::from([(Some(0.0), None)]));
        assert!(domain.contains::<U1, U1>(&OVector::from([0.0]).as_view()));
        assert!(domain.contains::<U1, U1>(&OVector::from([1e-10]).as_view()));
        assert!(domain.contains::<U1, U1>(&OVector::from([1e6]).as_view()));
        assert!(!domain.contains::<U1, U1>(&OVector::from([-1e-10]).as_view()));
    }
    #[test]
    fn test_half_bounded_upper() {
        let domain = Domain::new_mdomain(OVector::from([(None, Some(1.0))]));
        assert!(domain.contains::<U1, U1>(&OVector::from([1.0]).as_view()));
        assert!(domain.contains::<U1, U1>(&OVector::from([1.0 - 1e-10]).as_view()));
        assert!(domain.contains::<U1, U1>(&OVector::from([-1e6]).as_view()));
        assert!(!domain.contains::<U1, U1>(&OVector::from([1.0 + 1e-10]).as_view()));
    }
    #[test]
    fn test_maximum_values() {
        let domain =
            Domain::new_mdomain(OVector::from([(Some(0.0), Some(1.0)), (Some(-5.0), None)]));
        let maxes = domain.maximum_values();
        assert_eq!(maxes[0], Some(1.0));
        assert_eq!(maxes[1], None);
    }
    #[test]
    fn test_minimum_values() {
        let domain =
            Domain::new_mdomain(OVector::from([(Some(0.0), Some(1.0)), (None, Some(5.0))]));
        let mins = domain.minimum_values();
        assert_eq!(mins[0], Some(0.0));
        assert_eq!(mins[1], None);
    }
    #[test]
    fn test_size() {
        let domain = Domain::new_mdomain(OVector::from([
            (Some(-1.0), Some(3.0)),
            (Some(0.0), None),
            (None, Some(2.0)),
        ]));
        let size = domain.size();
        assert_eq!(size[0], Some(4.0));
        assert_eq!(size[1], None);
        assert_eq!(size[2], None);
    }
    #[test]
    fn test_udomain_accessors() {
        let domain: Domain<f64, U2> = Domain::new_udomain(U2);
        assert_eq!(domain.shape_generic().value(), 2);
        assert!(domain.minimum_values().iter().all(Option::is_none));
        assert!(domain.maximum_values().iter().all(Option::is_none));
        assert!(domain.size().iter().all(Option::is_none));
    }
    #[test]
    fn test_inner() {
        let bounded: Domain<f64, U1> =
            Domain::new_mdomain(OVector::from([(Some(0.0), Some(1.0))]));
        assert_eq!(bounded.inner(), Some((Some(0.0), Some(1.0))));
        let unbounded: Domain<f64, U1> = Domain::new_udomain(U1);
        assert_eq!(unbounded.inner(), None);
    }
    #[test]
    fn test_multidimensional_clamp_and_contains() {
        let domain =
            Domain::new_mdomain(OVector::from([(Some(0.0), Some(1.0)), (None, Some(-2.0))]));
        assert_eq!(domain.shape_generic().value(), 2);
        let inside = OVector::from([0.5, -3.0]);
        assert!(domain.contains::<U1, U2>(&inside.as_view()));
        let outside = OVector::from([1.5, -1.0]);
        assert!(!domain.contains::<U1, U2>(&outside.as_view()));
        let clamped = domain.clamp::<U1, U2>(&outside.as_view());
        assert_eq!(clamped[0], 1.0);
        assert_eq!(clamped[1], -2.0);
    }
}