use crate::Domain;
use nalgebra::{Dim, RealField, SVector, Scalar, U1, VectorView};
use serde::{Deserialize, Serialize};
/// A one-dimensional domain described by optional lower and upper bounds.
///
/// [`SDomain::new`] maps a pair of optional bounds onto the matching variant.
// NOTE(review): the variants are documented now, so this `allow` may be removable.
#[allow(missing_docs)]
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
pub enum SDomain<T> {
    /// Both bounds are present, stored as `Bounded(lower, upper)`.
    Bounded(T, T),
    /// A single admissible value (what `new` produces when both bounds are equal).
    Constant(T),
    /// Only a lower bound is present.
    LeftBounded(T),
    /// Only an upper bound is present.
    RightBounded(T),
    /// No bounds at all.
    Unbounded,
}
impl<T> SDomain<T>
where
    T: Clone + PartialOrd,
{
    /// Builds a one-dimensional domain from optional lower (`opt_min`) and upper (`opt_max`)
    /// bounds.
    ///
    /// * `(None, None)` gives [`SDomain::Unbounded`].
    /// * `(Some(a), None)` gives [`SDomain::LeftBounded`].
    /// * `(None, Some(b))` gives [`SDomain::RightBounded`].
    /// * `(Some(a), Some(b))` with `a < b` gives [`SDomain::Bounded`].
    /// * `(Some(a), Some(b))` with `a == b` gives [`SDomain::Constant`].
    ///
    /// Returns `None` when both bounds are given and they do not describe a valid interval.
    /// That happens when `a > b`, or when the bounds cannot be ordered at all (e.g. a
    /// floating-point NaN).
    pub fn new(opt_min: Option<T>, opt_max: Option<T>) -> Option<Self> {
        match (opt_min, opt_max) {
            (None, None) => Some(Self::Unbounded),
            (Some(a), None) => Some(Self::LeftBounded(a)),
            (None, Some(b)) => Some(Self::RightBounded(b)),
            (Some(a), Some(b)) => {
                if a < b {
                    Some(Self::Bounded(a, b))
                } else if a == b {
                    Some(Self::Constant(a))
                } else {
                    // Either `a > b` or the pair is unordered (both `<` and `==` are false,
                    // as with NaN). Neither describes a usable interval.
                    None
                }
            }
        }
    }

    /// Returns a domain with no bounds, i.e. [`SDomain::Unbounded`].
    pub fn unbounded() -> Self {
        Self::Unbounded
    }
}
impl<T> Domain<T, U1> for SDomain<T>
where
    T: PartialOrd + Scalar + Send + Sync,
{
    // Every method forwards to `impl Domain<T, U1> for &SDomain<T>`, which holds the actual
    // logic. The fully-qualified `<&Self as Domain<T, U1>>::…` form names that impl
    // explicitly. A plain `self.method(..)` here would resolve to *this* impl (its receiver
    // type is `&SDomain<T>`) and recurse forever.

    fn clip<RStride: Dim, CStride: Dim>(
        &self,
        sample: &VectorView<T, U1, RStride, CStride>,
    ) -> SVector<T, 1> {
        <&Self as Domain<T, U1>>::clip(&self, sample)
    }

    fn contains<RStride: Dim, CStride: Dim>(
        &self,
        sample: &VectorView<T, U1, RStride, CStride>,
    ) -> bool {
        <&Self as Domain<T, U1>>::contains(&self, sample)
    }

    fn maximum_values(&self) -> Option<SVector<T, 1>> {
        <&Self as Domain<T, U1>>::maximum_values(&self)
    }

    fn minimum_values(&self) -> Option<SVector<T, 1>> {
        <&Self as Domain<T, U1>>::minimum_values(&self)
    }

    fn size(&self) -> SVector<Option<T>, 1>
    where
        T: RealField,
    {
        <&Self as Domain<T, U1>>::size(&self)
    }
}
impl<T> Domain<T, U1> for &SDomain<T>
where
    T: PartialOrd + Scalar + Send + Sync,
{
    /// Clamps the sample's single component into the domain.
    ///
    /// A value below a lower bound becomes that bound, and a value above an upper bound becomes
    /// that bound. A [`SDomain::Constant`] domain always yields its value. Anything else,
    /// including a value that does not compare with the bounds (such as NaN), is returned
    /// unchanged.
    fn clip<RStride: Dim, CStride: Dim>(
        &self,
        sample: &VectorView<T, U1, RStride, CStride>,
    ) -> SVector<T, 1> {
        let x = &sample[0];
        // Pick the result by reference so exactly one clone happens at the end.
        let clipped = match self {
            SDomain::Bounded(a, b) => {
                if x < a {
                    a
                } else if x > b {
                    b
                } else {
                    x
                }
            }
            SDomain::Constant(a) => a,
            SDomain::LeftBounded(a) => {
                if x < a {
                    a
                } else {
                    x
                }
            }
            SDomain::RightBounded(b) => {
                if x > b {
                    b
                } else {
                    x
                }
            }
            SDomain::Unbounded => x,
        };
        SVector::from([clipped.clone()])
    }

    /// Reports whether the sample's single component lies inside the domain.
    ///
    /// Bounds are inclusive. This matches `clip`, which maps out-of-range values exactly onto
    /// the bounds, and `minimum_values`/`maximum_values`, which report the bounds. So a clipped
    /// sample is always contained, unless it is not comparable with the bounds (e.g. NaN).
    fn contains<RStride: Dim, CStride: Dim>(
        &self,
        sample: &VectorView<T, U1, RStride, CStride>,
    ) -> bool {
        let x = &sample[0];
        match self {
            SDomain::Bounded(a, b) => a <= x && x <= b,
            SDomain::Constant(a) => x == a,
            SDomain::LeftBounded(a) => a <= x,
            SDomain::RightBounded(b) => x <= b,
            SDomain::Unbounded => true,
        }
    }

    /// Returns the upper bound as a one-element vector.
    ///
    /// For [`SDomain::Constant`] this is its value. Returns `None` when the domain has no
    /// upper bound.
    fn maximum_values(&self) -> Option<SVector<T, 1>> {
        match self {
            SDomain::Bounded(_, hi) | SDomain::Constant(hi) | SDomain::RightBounded(hi) => {
                Some(SVector::from([hi.clone()]))
            }
            SDomain::LeftBounded(_) | SDomain::Unbounded => None,
        }
    }

    /// Returns the lower bound as a one-element vector.
    ///
    /// For [`SDomain::Constant`] this is its value. Returns `None` when the domain has no
    /// lower bound.
    fn minimum_values(&self) -> Option<SVector<T, 1>> {
        match self {
            SDomain::Bounded(lo, _) | SDomain::Constant(lo) | SDomain::LeftBounded(lo) => {
                Some(SVector::from([lo.clone()]))
            }
            SDomain::RightBounded(_) | SDomain::Unbounded => None,
        }
    }

    /// Returns the width of the domain: `upper - lower` for [`SDomain::Bounded`] and zero for
    /// [`SDomain::Constant`]. Returns `None` when either side is unbounded.
    fn size(&self) -> SVector<Option<T>, 1>
    where
        T: RealField,
    {
        let width = match self {
            SDomain::Bounded(a, b) => Some(b.clone() - a.clone()),
            SDomain::Constant(_) => Some(T::zero()),
            SDomain::LeftBounded(_) | SDomain::RightBounded(_) | SDomain::Unbounded => None,
        };
        SVector::from([width])
    }
}