use rand::{Rng, RngCore};
use crate::traits::SearchSpace;
/// Strategy for mapping a continuous coordinate onto an integer.
///
/// Optimizers move through `f64` space; a `Discretization` decides which
/// integer a given `f64` coordinate stands for.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Discretization {
    /// Nearest integer; halfway cases go away from zero (`f64::round`).
    Round,
    /// Drop the fractional part, i.e. round toward zero (`f64::trunc`).
    Truncate,
    /// Largest integer less than or equal to the value (`f64::floor`).
    Floor,
}

impl Discretization {
    /// Converts `x` to an integer according to this strategy.
    ///
    /// The final `f64 -> i64` step is an `as` cast, which saturates: values
    /// outside the `i64` range become `i64::MIN` / `i64::MAX`, and `NaN`
    /// becomes `0`.
    pub(crate) fn apply(self, x: f64) -> i64 {
        let integral = match self {
            Self::Round => x.round(),
            Self::Truncate => x.trunc(),
            Self::Floor => x.floor(),
        };
        integral as i64
    }
}
/// A box-shaped search space over integer vectors.
///
/// Dimension `i` is restricted to the inclusive range
/// `bounds[i].0..=bounds[i].1`. Optimizers work on continuous `f64`
/// coordinates; the configured [`Discretization`] turns them back into
/// integers.
#[derive(Debug, Clone)]
pub struct IntegerSpace {
    bounds: Vec<(i64, i64)>,
    discretization: Discretization,
}
impl IntegerSpace {
    /// Builds a space from explicit per-dimension inclusive `(lo, hi)` bounds,
    /// using [`Discretization::Round`].
    ///
    /// # Panics
    ///
    /// Panics if any dimension has `lo > hi`; the message names the first
    /// offending dimension.
    pub fn new(bounds: Vec<(i64, i64)>) -> Self {
        if let Some((i, (lo, hi))) = bounds.iter().enumerate().find(|(_, (lo, hi))| lo > hi) {
            panic!("invalid bound in dimension {i}: {lo} > {hi}");
        }
        Self {
            bounds,
            discretization: Discretization::Round,
        }
    }

    /// Replaces the discretization strategy, consuming and returning `self`.
    pub fn with_discretization(self, d: Discretization) -> Self {
        Self {
            discretization: d,
            ..self
        }
    }

    /// Builds a `dim`-dimensional space in which every dimension spans `lo..=hi`.
    ///
    /// # Panics
    ///
    /// Panics if `dim > 0` and `lo > hi`.
    pub fn uniform(dim: usize, lo: i64, hi: i64) -> Self {
        let bounds: Vec<(i64, i64)> = std::iter::repeat((lo, hi)).take(dim).collect();
        Self::new(bounds)
    }

    /// Returns the inclusive `(lo, hi)` bounds of every dimension.
    pub fn bounds(&self) -> &[(i64, i64)] {
        self.bounds.as_slice()
    }
}
impl SearchSpace for IntegerSpace {
    type Scalar = i64;

    /// Number of integer dimensions.
    fn dim(&self) -> usize {
        self.bounds.len()
    }

    /// Draws an initial position whose *decoded* value is uniform over each
    /// dimension's integers `lo..=hi`.
    ///
    /// Integers are sampled directly instead of a continuous value in
    /// `[lo, hi]`. The discretization cells at the edges of the range are
    /// narrower than interior cells. With `Round` the endpoints would come up
    /// half as often as interior values. With `Floor`/`Truncate` some
    /// endpoints (e.g. `hi` under `Floor`) would essentially never be drawn.
    /// An integral `f64` decodes to itself under every [`Discretization`]
    /// variant, so this keeps the decoded distribution uniform.
    fn sample(&self, rng: &mut dyn RngCore) -> Vec<f64> {
        self.bounds
            .iter()
            .map(|&(lo, hi)| rng.gen_range(lo..=hi) as f64)
            .collect()
    }

    /// Draws an initial velocity uniformly from `[-(hi - lo), hi - lo]` in
    /// each dimension.
    fn sample_velocity(&self, rng: &mut dyn RngCore) -> Vec<f64> {
        self.bounds
            .iter()
            .map(|&(lo, hi)| {
                // Subtract in f64: `hi - lo` in i64 overflows for wide bounds
                // such as `(i64::MIN, i64::MAX)` (panic in debug, wrap in release).
                let range = hi as f64 - lo as f64;
                rng.gen_range(-range..=range)
            })
            .collect()
    }

    /// Clamps each coordinate of `position` into its dimension's `[lo, hi]`
    /// interval in place. Coordinates beyond the number of dimensions are
    /// left untouched (the iteration is a `zip`).
    fn clamp(&self, position: &mut [f64]) {
        for (x, &(lo, hi)) in position.iter_mut().zip(&self.bounds) {
            *x = x.clamp(lo as f64, hi as f64);
        }
    }

    /// Applies the requested boundary-handling strategy by delegating to the
    /// parent module's `apply_boundary`. Each dimension's integer bounds are
    /// supplied as `f64`.
    fn enforce_bounds(
        &self,
        position: &mut [f64],
        velocity: &mut [f64],
        handling: crate::traits::BoundaryHandling,
        rng: &mut dyn RngCore,
    ) {
        super::apply_boundary(
            position,
            velocity,
            |i| {
                let (lo, hi) = self.bounds[i];
                (lo as f64, hi as f64)
            },
            handling,
            rng,
        );
    }

    /// Converts raw continuous coordinates to integers with the configured
    /// discretization, then clamps each result into `[lo, hi]`.
    ///
    /// The output has `min(raw.len(), dim())` elements because of the `zip`.
    fn decode(&self, raw: &[f64]) -> Vec<i64> {
        raw.iter()
            .zip(&self.bounds)
            .map(|(&x, &(lo, hi))| self.discretization.apply(x).clamp(lo, hi))
            .collect()
    }

    /// Returns each dimension's bounds converted to `f64`.
    fn span(&self) -> Vec<(f64, f64)> {
        self.bounds
            .iter()
            .map(|&(lo, hi)| (lo as f64, hi as f64))
            .collect()
    }
}