use crate::{Interval, prelude::*};
use std::{cmp, fmt, ops::Range};
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub struct Equipartition<const N: usize> {
lb: f64,
ub: f64,
}
impl<const N: usize> Equipartition<N> {
pub fn new(lb: f64, ub: f64) -> Equipartition<N> {
if N == 0 {
panic!("A partition must have a number partitions of 1 or greater.")
}
Equipartition { lb, ub, }
}
pub fn from_interval<I: Into<Interval>>(d: I) -> Equipartition<N> {
let interval = d.into();
Equipartition {
lb: interval.lb.expect("Must be a bounded interval."),
ub: interval.ub.expect("Must be a bounded interval."),
}
}
#[inline]
pub fn n_partitions(&self) -> usize { N }
#[inline]
pub fn partition_width(&self) -> f64 { (self.ub - self.lb) / N as f64 }
pub fn centres(&self) -> [f64; N] {
let w = self.partition_width();
let hw = w / 2.0;
let mut output = [f64::default(); N];
for i in 0..N {
output[i] = self.lb + w * ((i + 1) as f64) - hw;
}
output
}
pub fn edges(&self) -> [f64; N] {
let w = self.partition_width();
let mut output = [f64::default(); N];
for i in 0..N {
output[i] = self.lb + w * (i as f64);
}
output
}
pub fn to_partition(&self, val: f64) -> usize {
let clipped = clip!(self.lb, val, self.ub);
let diff = clipped - self.lb;
let range = self.ub - self.lb;
let i = ((N as f64) * diff / range).floor() as usize;
if i >= N { N - 1 } else { i }
}
}
impl<const N: usize> Space for Equipartition<N> {
const DIM: usize = 1;
type Value = usize;
fn card(&self) -> Card { Card::Finite(N) }
fn contains(&self, val: &usize) -> bool { *val < N }
}
impl<const N: usize> OrderedSpace for Equipartition<N> {
fn min(&self) -> Option<usize> { Some(0) }
fn max(&self) -> Option<usize> { Some(N - 1) }
}
impl<const N: usize> FiniteSpace for Equipartition<N> {
fn to_ordinal(&self) -> Range<Self::Value> { 0..N }
}
impl<const N: usize, const M: usize> cmp::PartialEq<Equipartition<M>> for Equipartition<N> {
fn eq(&self, other: &Equipartition<M>) -> bool {
N.eq(&M) && self.lb.eq(&other.lb) && self.ub.eq(&other.ub)
}
}
impl<const N: usize> fmt::Display for Equipartition<N> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match N {
n if n == 1 => write!(f, "{{{} = x0, x1 = {}}}", self.lb, self.ub),
n if n == 2 => write!(f, "{{{} = x0, x1, x2 = {}}}", self.lb, self.ub),
n => write!(f, "{{{} = x0, x1, ..., x{} = {}}}", self.lb, n, self.ub),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[cfg(feature = "serialize")]
extern crate serde_test;
#[cfg(feature = "serialize")]
use self::serde_test::{assert_tokens, Token};
#[test]
fn test_from_interval() {
assert_eq!(
Equipartition::<5>::new(0.0, 5.0),
Equipartition::<5>::from_interval(Interval::bounded(0.0, 5.0))
);
}
#[test]
fn test_density() {
assert_eq!(Equipartition::<5>::new(0.0, 5.0).n_partitions(), 5);
assert_eq!(Equipartition::<10>::new(0.0, 5.0).n_partitions(), 10);
assert_eq!(Equipartition::<100>::new(-5.0, 5.0).n_partitions(), 100);
}
#[test]
fn test_partition_width() {
assert_eq!(Equipartition::<5>::new(0.0, 5.0).partition_width(), 1.0);
assert_eq!(Equipartition::<10>::new(0.0, 5.0).partition_width(), 0.5);
assert_eq!(Equipartition::<10>::new(-5.0, 5.0).partition_width(), 1.0);
}
#[test]
fn test_centres() {
assert_eq!(
Equipartition::new(0.0, 5.0).centres(),
[0.5, 1.5, 2.5, 3.5, 4.5]
);
assert_eq!(
Equipartition::new(-5.0, 5.0).centres(),
[-4.0, -2.0, 0.0, 2.0, 4.0]
);
}
#[test]
fn test_to_partition() {
let d = Equipartition::<6>::new(0.0, 5.0);
assert_eq!(d.to_partition(-1.0), 0);
assert_eq!(d.to_partition(0.0), 0);
assert_eq!(d.to_partition(1.0), 1);
assert_eq!(d.to_partition(2.0), 2);
assert_eq!(d.to_partition(3.0), 3);
assert_eq!(d.to_partition(4.0), 4);
assert_eq!(d.to_partition(5.0), 5);
assert_eq!(d.to_partition(6.0), 5);
}
#[test]
fn test_dim() {
assert_eq!(Equipartition::<1>::DIM, 1);
assert_eq!(Equipartition::<5>::DIM, 1);
assert_eq!(Equipartition::<10>::DIM, 1);
}
#[test]
fn test_card() {
fn check<const N: usize>(lb: f64, ub: f64) {
let d = Equipartition::<N>::new(lb, ub);
assert_eq!(d.card(), Card::Finite(N));
}
check::<5>(0.0, 5.0);
check::<5>(-5.0, 0.0);
check::<10>(-5.0, 5.0);
}
#[test]
fn test_to_ordinal() {
fn check<const N: usize>(lb: f64, ub: f64) {
let d = Equipartition::<N>::new(lb, ub);
assert_eq!(d.to_ordinal(), 0..N);
}
check::<5>(0.0, 5.0);
check::<5>(-5.0, 0.0);
check::<10>(-5.0, 5.0);
}
#[cfg(feature = "serialize")]
#[test]
fn test_serialisation() {
fn check(lb: f64, ub: f64, n_partitions: usize) {
let d = Equipartition::new(lb, ub, n_partitions);
assert_tokens(
&d,
&[
Token::Struct {
name: "Equipartition",
len: 3,
},
Token::Str("lb"),
Token::F64(lb),
Token::Str("ub"),
Token::F64(ub),
Token::Str("n_partitions"),
Token::U64(n_partitions as u64),
Token::StructEnd,
],
);
}
check(0.0, 5.0, 5);
check(-5.0, 5.0, 10);
check(-5.0, 0.0, 5);
}
}