use std::ops::RangeInclusive;
use std::time::Duration;
use super::{RestartStrategy, Restarter};
use crate::Mode;
/// Error returned by [`RestartOptions::build`] when the options cannot be turned into a
/// [`Restarter`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum InvalidRestartOptionsError {
    /// The number of dimensions was rejected (e.g. zero dimensions).
    Dimensions,
    /// The search range was rejected (e.g. a zero-width range such as `2.0..=2.0`).
    SearchRange,
}

impl std::fmt::Display for InvalidRestartOptionsError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            Self::Dimensions => "invalid number of dimensions",
            Self::SearchRange => "invalid search range",
        })
    }
}

// Implementing `Error` lets callers propagate this through `Box<dyn Error>` / `?`.
impl std::error::Error for InvalidRestartOptionsError {}
/// Error describing invalid restart-strategy options.
///
/// NOTE(review): presumably returned by strategy constructors such as `Local::new`
/// (which the tests `unwrap`) — confirm against the strategy modules.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum InvalidRestartStrategyOptionsError {
    /// The initial step size was rejected.
    InitialStepSize,
    /// The population size was rejected.
    PopulationSize,
}

impl std::fmt::Display for InvalidRestartStrategyOptionsError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            Self::InitialStepSize => "invalid initial step size",
            Self::PopulationSize => "invalid population size",
        })
    }
}

// Implementing `Error` lets callers propagate this through `Box<dyn Error>` / `?`.
impl std::error::Error for InvalidRestartStrategyOptionsError {}
/// Configuration for a [`Restarter`], built with [`RestartOptions::new`] and the
/// chained setter methods, then validated by [`RestartOptions::build`].
#[derive(Debug, Clone)]
pub struct RestartOptions {
    /// The restart strategy to use.
    pub strategy: RestartStrategy,
    /// Number of dimensions; `build` rejects zero.
    pub dimensions: usize,
    /// Optimization mode; `RestartOptions::new` sets `Mode::Minimize`.
    pub mode: Mode,
    /// Parallel-update flag; `false` unless changed via the setter.
    /// NOTE(review): exact effect is decided by `Restarter` — confirm there.
    pub parallel_update: bool,
    /// Search range (assumed to apply to every dimension — confirm in `Restarter`).
    /// `RestartOptions::new` swaps a reversed range; `build` rejects a zero-width one.
    pub search_range: RangeInclusive<f64>,
    /// Optional target objective value; unset by default.
    pub fun_target: Option<f64>,
    /// Optional overall limit on function evaluations (presumably summed over all runs).
    pub max_function_evals: Option<usize>,
    /// Optional overall time limit (presumably wall-clock).
    pub max_time: Option<Duration>,
    /// Optional limit on function evaluations within a single run.
    pub max_function_evals_per_run: Option<usize>,
    /// Optional limit on generations within a single run.
    pub max_generations_per_run: Option<usize>,
    /// Printing flag; `false` unless changed via the setter.
    pub enable_printing: bool,
    /// Optional seed (presumably for the random number generator); unset by default.
    pub seed: Option<u64>,
}
impl RestartOptions {
    /// Creates options with the given dimension count, search range and strategy.
    ///
    /// Defaults: `Mode::Minimize`, `parallel_update = false`, `enable_printing = false`,
    /// and no target, limits or seed.
    ///
    /// A reversed `search_range` (start greater than end) is normalized by swapping its
    /// bounds. A range with a NaN bound also reports `is_empty()`, so it is swapped too
    /// but keeps its NaN; whether it is rejected is up to [`RestartOptions::build`].
    #[must_use]
    pub fn new(
        dimensions: usize,
        search_range: RangeInclusive<f64>,
        strategy: RestartStrategy,
    ) -> Self {
        // `RangeInclusive::is_empty` is true whenever `!(start <= end)`.
        let search_range = if search_range.is_empty() {
            *search_range.end()..=*search_range.start()
        } else {
            search_range
        };

        Self {
            strategy,
            dimensions,
            mode: Mode::Minimize,
            parallel_update: false,
            search_range,
            fun_target: None,
            max_function_evals: None,
            max_time: None,
            max_function_evals_per_run: None,
            max_generations_per_run: None,
            enable_printing: false,
            seed: None,
        }
    }

    /// Sets the optimization mode.
    #[must_use]
    pub fn mode(mut self, mode: Mode) -> Self {
        self.mode = mode;
        self
    }

    /// Sets the `parallel_update` flag.
    #[must_use]
    pub fn parallel_update(mut self, parallel_update: bool) -> Self {
        self.parallel_update = parallel_update;
        self
    }

    /// Sets the target objective value.
    #[must_use]
    pub fn fun_target(mut self, fun_target: f64) -> Self {
        self.fun_target = Some(fun_target);
        self
    }

    /// Sets the overall limit on function evaluations.
    #[must_use]
    pub fn max_function_evals(mut self, function_evals: usize) -> Self {
        self.max_function_evals = Some(function_evals);
        self
    }

    /// Sets the overall time limit.
    #[must_use]
    pub fn max_time(mut self, max_time: Duration) -> Self {
        self.max_time = Some(max_time);
        self
    }

    /// Sets the per-run limit on function evaluations.
    #[must_use]
    pub fn max_function_evals_per_run(mut self, function_evals: usize) -> Self {
        self.max_function_evals_per_run = Some(function_evals);
        self
    }

    /// Sets the per-run limit on generations.
    #[must_use]
    pub fn max_generations_per_run(mut self, generations: usize) -> Self {
        self.max_generations_per_run = Some(generations);
        self
    }

    /// Sets the `enable_printing` flag.
    #[must_use]
    pub fn enable_printing(mut self, enable_printing: bool) -> Self {
        self.enable_printing = enable_printing;
        self
    }

    /// Sets the seed.
    #[must_use]
    pub fn seed(mut self, seed: u64) -> Self {
        self.seed = Some(seed);
        self
    }

    /// Validates the options and creates a [`Restarter`].
    ///
    /// # Errors
    ///
    /// Returns an [`InvalidRestartOptionsError`] when validation fails, e.g. zero
    /// dimensions (`Dimensions`) or a zero-width search range (`SearchRange`).
    // No `#[must_use]` here: the returned `Result` is already must-use.
    pub fn build(self) -> Result<Restarter, InvalidRestartOptionsError> {
        Restarter::new(self)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::restart::Local;

    /// Builds the simple local restart strategy shared by the tests below.
    fn local_strategy() -> RestartStrategy {
        RestartStrategy::Local(Local::new(10, None).unwrap())
    }

    #[test]
    fn test_build() {
        assert!(RestartOptions::new(2, 0.0..=1.0, local_strategy())
            .build()
            .is_ok());
        assert!(matches!(
            RestartOptions::new(0, 0.0..=1.0, local_strategy()).build(),
            Err(InvalidRestartOptionsError::Dimensions)
        ));
        assert!(matches!(
            RestartOptions::new(2, 2.0..=2.0, local_strategy()).build(),
            Err(InvalidRestartOptionsError::SearchRange)
        ));
    }

    #[test]
    fn test_reversed_search_range_is_normalized() {
        // `new` swaps the bounds of a reversed range instead of rejecting it.
        let options = RestartOptions::new(2, 1.0..=0.0, local_strategy());
        assert_eq!(options.search_range, 0.0..=1.0);
        assert!(options.build().is_ok());
    }
}