use super::Gene;
use rand::{Rng, rngs::ThreadRng};
use serde::{Deserialize, Serialize};
use tracing::instrument;
/// Errors raised while constructing [`GeneBounds`] or encoding values with them.
#[derive(Debug, thiserror::Error)]
pub enum GeneBoundError {
    /// The requested bounds are unusable, e.g. `lower > upper`, or
    /// `lower == upper` while more than one step was requested.
    #[error(
        "InvalidBounds: lower bound must be smaller than upper. lower = {lower}, upper={upper}"
    )]
    InvalidBound {
        /// Rejected lower bound.
        lower: f64,
        /// Rejected upper bound.
        upper: f64,
    },

    /// The step count does not fit the accepted maximum (`max`).
    #[error("StepsOverflow: steps is too large. steps={steps}, max={max}")]
    StepsOverflow {
        /// Requested number of steps.
        steps: u32,
        /// Largest accepted number of steps.
        max: i32,
    },

    /// A step count of zero was requested.
    #[error("ZeroSteps: number of steps must be greater than 0")]
    ZeroSteps,

    /// A value to encode lies outside the inclusive range `[lower, upper]`.
    #[error("ValueOutOfBounds: value {value} is outside bounds [{lower}, {upper}]")]
    ValueOutOfBounds {
        /// The offending value.
        value: f64,
        /// Inclusive lower bound of the gene.
        lower: f64,
        /// Inclusive upper bound of the gene.
        upper: f64,
    },
}
impl GeneBoundError {
pub(crate) fn steps_overflow(steps: u32) -> Self {
Self::StepsOverflow {
steps,
max: i32::MAX,
}
}
pub(crate) fn invalid_bound(lower: f64, upper: f64) -> Self {
Self::InvalidBound { lower, upper }
}
pub(crate) fn zero_steps() -> Self {
Self::ZeroSteps
}
pub(crate) fn value_out_of_bounds(value: f64, lower: f64, upper: f64) -> Self {
Self::ValueOutOfBounds {
value,
lower,
upper,
}
}
}
/// Fixed-point description of the values a single gene may take.
///
/// Both bounds are stored pre-multiplied by `scale_factor` (`10^precision` for
/// [`GeneBounds::decimal`], `1` for [`GeneBounds::integer`]), so decoding works in
/// integer arithmetic. A gene is an index in `0..steps` that is mapped linearly
/// onto `[lower_scaled, upper_scaled]` and then divided by `scale_factor`.
///
/// NOTE(review): `Deserialize` fills these fields directly and bypasses the
/// validation done by the constructors (e.g. `steps == 0` is not rejected) —
/// confirm that deserialized data is trusted.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(test, derive(PartialEq))]
pub struct GeneBounds {
    // Lower bound multiplied by `scale_factor` (rounded to an integer).
    pub(crate) lower_scaled: i64,
    // Upper bound multiplied by `scale_factor` (rounded to an integer).
    pub(crate) upper_scaled: i64,
    // Number of distinct gene values; the constructors enforce 1..=i32::MAX.
    pub(crate) steps: u32,
    // Divisor that turns a scaled integer back into the real-valued gene.
    pub(crate) scale_factor: i64,
}
impl GeneBounds {
    /// Creates bounds for a real-valued gene with `precision` decimal places.
    ///
    /// `lower` and `upper` are multiplied by `10^precision` and rounded, so the
    /// effective bounds are quantised to that many decimals. The gene takes one
    /// of `steps` evenly spaced values in `[lower, upper]`.
    ///
    /// # Errors
    ///
    /// - [`GeneBoundError::InvalidBound`] if either bound is NaN or infinite, if
    ///   `lower > upper`, or if `lower == upper` while `steps > 1`.
    /// - [`GeneBoundError::ZeroSteps`] if `steps == 0`.
    /// - [`GeneBoundError::StepsOverflow`] if `steps > i32::MAX`.
    ///
    /// # Panics
    ///
    /// `precision > 18` overflows `10^precision` in `i64`. With default build
    /// settings this panics in debug builds and wraps in release builds. Bounds
    /// whose scaled magnitude exceeds the `i64` range saturate on conversion.
    #[instrument(level = "debug", fields(lower = lower, upper = upper, steps = steps, precision = precision))]
    pub fn decimal(
        lower: f64,
        upper: f64,
        steps: u32,
        precision: u8,
    ) -> Result<Self, GeneBoundError> {
        Self::validate_bounds(lower, upper, steps)?;
        let scale_factor = 10_i64.pow(u32::from(precision));
        let lower_scaled = (lower * scale_factor as f64).round() as i64;
        let upper_scaled = (upper * scale_factor as f64).round() as i64;
        Ok(Self {
            lower_scaled,
            upper_scaled,
            steps,
            scale_factor,
        })
    }

    /// Creates bounds for an integer-valued gene (scale factor 1).
    ///
    /// # Errors
    ///
    /// - [`GeneBoundError::InvalidBound`] if `lower > upper`, or if
    ///   `lower == upper` while `steps > 1`.
    /// - [`GeneBoundError::ZeroSteps`] if `steps == 0`.
    /// - [`GeneBoundError::StepsOverflow`] if `steps > i32::MAX`.
    #[instrument(level = "debug", fields(lower = lower, upper = upper, steps = steps))]
    pub fn integer(lower: i32, upper: i32, steps: u32) -> Result<Self, GeneBoundError> {
        Self::validate_bounds(lower, upper, steps)?;
        Ok(Self {
            lower_scaled: i64::from(lower),
            upper_scaled: i64::from(upper),
            steps,
            scale_factor: 1,
        })
    }

    /// Shared validation for [`Self::decimal`] and [`Self::integer`].
    ///
    /// Bounds are compared as `f64`. Non-finite bounds are rejected: NaN would
    /// otherwise pass every comparison below, because comparisons with NaN are
    /// always false. Infinities would saturate when scaled to `i64`.
    fn validate_bounds<T>(lower: T, upper: T, steps: u32) -> Result<(), GeneBoundError>
    where
        T: Into<f64>,
    {
        let lower: f64 = lower.into();
        let upper: f64 = upper.into();
        if !lower.is_finite() || !upper.is_finite() || lower > upper {
            return Err(GeneBoundError::invalid_bound(lower, upper));
        }
        if steps == 0 {
            return Err(GeneBoundError::zero_steps());
        }
        if steps > i32::MAX as u32 {
            return Err(GeneBoundError::steps_overflow(steps));
        }
        // A single point can only be represented by a single step.
        if lower == upper && steps > 1 {
            return Err(GeneBoundError::invalid_bound(lower, upper));
        }
        Ok(())
    }

    /// Decodes a gene index in `0..steps` to its real value.
    ///
    /// Index `0` maps to the lower bound and `steps - 1` maps to the upper bound.
    /// With a single step, every gene decodes to the lower bound.
    pub fn decode_f64(&self, gene: Gene) -> f64 {
        // i128 intermediates: `gene * range` can overflow i64 for wide decimal
        // bounds (e.g. precision 6 with bounds around 1e6 and many steps).
        let lower = i128::from(self.lower_scaled);
        let range = i128::from(self.upper_scaled) - lower;
        let scaled_value = match self.steps {
            // Avoid dividing by `steps - 1 == 0`; only one value is representable.
            0 | 1 => lower,
            steps => lower + (i128::from(gene) * range) / i128::from(steps - 1),
        };
        scaled_value as f64 / self.scale_factor as f64
    }

    /// Draws a random gene index in `0..steps`.
    #[instrument(level = "debug", skip(rng), fields(steps = self.steps))]
    pub(crate) fn random(&self, rng: &mut ThreadRng) -> Gene {
        rng.random_range(0..i64::from(self.steps))
    }

    /// Number of distinct gene values.
    ///
    /// The constructors guarantee `steps <= i32::MAX`, so the cast does not
    /// wrap for validated bounds.
    pub(crate) fn steps(&self) -> i32 {
        self.steps as i32
    }

    /// Maps a sample in `[0, 1]` to the nearest gene index.
    ///
    /// Samples outside `[0, 1]` are not clamped and yield indices outside
    /// `0..steps`.
    pub fn from_sample(&self, sample: f64) -> Gene {
        // saturating_sub: a deserialized `steps == 0` must not underflow.
        (sample * f64::from(self.steps.saturating_sub(1))).round() as Gene
    }

    /// Encodes a real value in `[lower, upper]` to the nearest gene index.
    ///
    /// # Errors
    ///
    /// [`GeneBoundError::ValueOutOfBounds`] if `value` is outside the inclusive
    /// bounds or is NaN.
    pub fn encode_f64(&self, value: f64) -> Result<Gene, GeneBoundError> {
        let scale = self.scale_factor as f64;
        let lower = self.lower_scaled as f64 / scale;
        let upper = self.upper_scaled as f64 / scale;
        // `contains` is false for NaN, so NaN is rejected instead of being
        // silently encoded to an arbitrary gene.
        if !(lower..=upper).contains(&value) {
            return Err(GeneBoundError::value_out_of_bounds(value, lower, upper));
        }
        let range = i128::from(self.upper_scaled) - i128::from(self.lower_scaled);
        if range == 0 {
            // Degenerate bounds: the only representable gene is index 0.
            return Ok(0);
        }
        let scaled_input = (value * scale).round() as i64;
        let offset = i128::from(scaled_input) - i128::from(self.lower_scaled);
        Ok(self.from_sample(offset as f64 / range as f64))
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for [`GeneBounds`] validation, encoding and decoding.
    use super::*;

    #[test]
    fn test_bounds_ordering_validation() {
        assert!(GeneBounds::decimal(0.5, 0.23, 10, 6).is_err());
        assert!(GeneBounds::integer(10, 5, 10).is_err());
    }

    #[test]
    fn test_zero_steps_validation() {
        assert!(GeneBounds::decimal(0.0, 1.0, 0, 6).is_err());
        assert!(GeneBounds::integer(0, 10, 0).is_err());
    }

    #[test]
    fn test_steps_overflow_validation() {
        assert!(GeneBounds::decimal(0.0, 1.0, u32::MAX, 6).is_err());
        assert!(GeneBounds::integer(0, 10, u32::MAX).is_err());
    }

    #[test]
    fn test_equal_bounds_validation() {
        assert!(GeneBounds::decimal(0.5, 0.5, 2, 6).is_err());
        assert!(GeneBounds::integer(5, 5, 2).is_err());
        assert!(GeneBounds::decimal(0.5, 0.5, 1, 6).is_ok());
        assert!(GeneBounds::integer(5, 5, 1).is_ok());
    }

    #[test]
    fn test_from_sample_boundaries() {
        let bounds = GeneBounds::decimal(0.0, 10.0, 100, 3).unwrap();
        assert_eq!(bounds.from_sample(0.0), 0);
        assert_eq!(bounds.from_sample(1.0), 99);
        assert_eq!(bounds.from_sample(0.5), 50);
    }

    #[test]
    fn test_from_sample_single_step() {
        let bounds = GeneBounds::decimal(0.0, 1.0, 1, 3).unwrap();
        assert_eq!(bounds.from_sample(0.0), 0);
        assert_eq!(bounds.from_sample(1.0), 0);
    }

    #[test]
    fn test_from_sample_different_step_counts() {
        let bounds = GeneBounds::decimal(0.0, 5.0, 10, 2).unwrap();
        assert_eq!(bounds.from_sample(0.0), 0);
        assert_eq!(bounds.from_sample(1.0), 9);
        assert_eq!(bounds.from_sample(0.5), 5);
    }

    #[test]
    fn test_bounds_inclusion() {
        let bounds = GeneBounds::decimal(0.0, 1.0, 2, 3).unwrap();
        assert_eq!(bounds.decode_f64(0), 0.0);
        assert_eq!(bounds.decode_f64(1), 1.0);
        assert_eq!(bounds.from_sample(0.0), 0);
        assert_eq!(bounds.from_sample(1.0), 1);

        let bounds_5 = GeneBounds::decimal(0.0, 4.0, 5, 1).unwrap();
        for gene in 0..5 {
            assert_eq!(bounds_5.decode_f64(gene), gene as f64);
        }
    }

    #[test]
    fn test_integer_bounds_internal_representation() {
        let bounds = GeneBounds::integer(1, 10, 10).unwrap();
        assert_eq!(bounds.lower_scaled, 1);
        assert_eq!(bounds.upper_scaled, 10);
        assert_eq!(bounds.steps, 10);
        assert_eq!(bounds.scale_factor, 1);
        assert_eq!(bounds.steps(), 10);
    }

    #[test]
    fn test_decimal_bounds_internal_representation() {
        let bounds = GeneBounds::decimal(0.23, 0.5, 500, 6).unwrap();
        assert_eq!(bounds.lower_scaled, 230000);
        assert_eq!(bounds.upper_scaled, 500000);
        assert_eq!(bounds.steps, 500);
        assert_eq!(bounds.scale_factor, 1_000_000);
    }

    #[test]
    fn test_decode_f64_conversion() {
        let bounds = GeneBounds::decimal(0.0, 10.0, 11, 1).unwrap();
        assert_eq!(bounds.decode_f64(0), 0.0);
        assert_eq!(bounds.decode_f64(5), 5.0);
        assert_eq!(bounds.decode_f64(10), 10.0);

        let precise_bounds = GeneBounds::decimal(0.0, 1.0, 101, 3).unwrap();
        assert!((precise_bounds.decode_f64(0) - 0.0).abs() < 0.001);
        assert!((precise_bounds.decode_f64(50) - 0.5).abs() < 0.01);
        assert!((precise_bounds.decode_f64(100) - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_random_generates_different_values() {
        let bounds = GeneBounds::integer(0, 1000, 100).unwrap();
        let mut rng = rand::rng();
        // Two draws from 100 values collide ~1% of the time, which made the old
        // two-draw `assert_ne!` flaky. Take many draws and only require that
        // they are not all identical (odds of failure ~100^-63).
        let draws: Vec<_> = (0..64).map(|_| bounds.random(&mut rng)).collect();
        assert!(draws.iter().all(|&gene| (0..100).contains(&gene)));
        assert!(draws.iter().any(|&gene| gene != draws[0]));
    }

    #[test]
    fn test_encode_f64_decimal_bounds() {
        let bounds = GeneBounds::decimal(0.0, 1.0, 11, 2).unwrap();
        assert_eq!(bounds.encode_f64(0.0).unwrap(), 0);
        assert_eq!(bounds.encode_f64(1.0).unwrap(), 10);
        let gene = bounds.encode_f64(0.5).unwrap();
        assert_eq!(gene, 5);
        let decoded = bounds.decode_f64(gene);
        assert!((decoded - 0.5).abs() < 0.01);
    }

    #[test]
    fn test_encode_f64_integer_bounds() {
        let bounds = GeneBounds::integer(10, 20, 11).unwrap();
        assert_eq!(bounds.encode_f64(10.0).unwrap(), 0);
        assert_eq!(bounds.encode_f64(20.0).unwrap(), 10);
        let gene = bounds.encode_f64(15.0).unwrap();
        assert_eq!(gene, 5);
        let decoded = bounds.decode_f64(gene);
        assert!((decoded - 15.0).abs() < 0.1);
    }

    #[test]
    fn test_encode_f64_out_of_bounds() {
        let bounds = GeneBounds::decimal(0.001, 1.0, 1000, 3).unwrap();
        assert!(bounds.encode_f64(0.0).is_err());
        assert!(bounds.encode_f64(1.1).is_err());
        assert!(bounds.encode_f64(0.001).is_ok());
        assert!(bounds.encode_f64(1.0).is_ok());
    }

    #[test]
    fn test_encode_f64_precision() {
        let bounds = GeneBounds::decimal(0.0, 1.0, 1001, 3).unwrap();
        let gene = bounds.encode_f64(0.123).unwrap();
        let decoded = bounds.decode_f64(gene);
        assert!((decoded - 0.123).abs() < 0.002);
    }

    #[test]
    fn test_encode_f64_error_message() {
        let bounds = GeneBounds::decimal(0.5, 1.5, 100, 2).unwrap();
        let result = bounds.encode_f64(0.3);
        assert!(result.is_err());
        let error_msg = result.unwrap_err().to_string();
        assert!(error_msg.contains("0.3"));
        assert!(error_msg.contains("0.5"));
        assert!(error_msg.contains("1.5"));
    }
}