use super::termination_model_error::TerminationModelError;
use crate::util::duration_extension::DurationExtension;
use serde::Deserialize;
use std::time::{Duration, Instant};
/// A rule (or set of rules) deciding when a search should stop.
///
/// Deserializable with serde using its default, externally tagged enum representation;
/// each variant is keyed by the tag in its `#[serde(rename = ...)]` attribute
/// (`query_runtime`, `solution_size`, `iterations`, `combined`).
#[derive(Debug, Deserialize)]
pub enum TerminationModel {
    /// Stop once the time elapsed since the search started is greater than `limit`.
    #[serde(rename = "query_runtime")]
    QueryRuntimeLimit {
        /// Maximum allowed elapsed wall-clock time.
        limit: Duration,
        /// Elapsed time is only sampled every `frequency` iterations
        /// (see `TerminationModel::terminate_search`).
        frequency: u64,
    },
    /// Stop once the current solution size is strictly greater than `limit`.
    #[serde(rename = "solution_size")]
    SolutionSizeLimit {
        /// Largest solution size that does not trigger termination.
        limit: usize,
    },
    /// Stop once `iteration + 1` exceeds `limit`, i.e. the iteration index reaches `limit`.
    #[serde(rename = "iterations")]
    IterationsLimit {
        /// Maximum number of iterations.
        limit: u64,
    },
    /// Stop when any of the nested models says to stop.
    #[serde(rename = "combined")]
    Combined {
        /// Nested models, evaluated in declaration order.
        models: Vec<TerminationModel>,
    },
}
impl TerminationModel {
    /// Decides whether the search should stop at the current point.
    ///
    /// * `start_time` - when the search began; only read by `QueryRuntimeLimit`.
    /// * `solution_size` - current solution size; only read by `SolutionSizeLimit`.
    /// * `iteration` - index of the current iteration; read by `QueryRuntimeLimit`
    ///   (sampling frequency) and `IterationsLimit`.
    ///
    /// Returns `Ok(true)` when the model's limit has been exceeded. For `Combined`,
    /// every nested model is evaluated and the results are OR-ed together.
    ///
    /// A `QueryRuntimeLimit` only looks at the clock when `iteration` is a multiple of
    /// `frequency`; on other iterations it reports `false`. A `frequency` of 0 is
    /// treated as "check on every iteration" instead of dividing by zero.
    ///
    /// # Errors
    ///
    /// Propagates the first error returned by a nested model of a `Combined`.
    /// The leaf variants themselves never return an error.
    pub fn terminate_search(
        &self,
        start_time: &Instant,
        solution_size: usize,
        iteration: u64,
    ) -> Result<bool, TerminationModelError> {
        use TerminationModel as T;
        match self {
            T::QueryRuntimeLimit { limit, frequency } => {
                // Only sample the clock every `frequency` iterations; `frequency == 0`
                // would otherwise panic on the modulo.
                let sample_now = *frequency == 0 || iteration % frequency == 0;
                Ok(sample_now && start_time.elapsed() > *limit)
            }
            T::SolutionSizeLimit { limit } => Ok(solution_size > *limit),
            // Same as `iteration + 1 > limit`, but cannot overflow at `u64::MAX`.
            T::IterationsLimit { limit } => Ok(iteration >= *limit),
            T::Combined { models } => models.iter().try_fold(false, |stop, model| {
                model
                    .terminate_search(start_time, solution_size, iteration)
                    .map(|fired| stop || fired)
            }),
        }
    }

    /// Builds a human-readable reason for termination, or `None` if no limit fired.
    ///
    /// Pass the same arguments that were given to
    /// [`TerminationModel::terminate_search`]. Each leaf model re-evaluates its own
    /// condition, so a `QueryRuntimeLimit` only reports when `iterations` falls on
    /// one of its sampling iterations.
    ///
    /// For `Combined`, the reasons of all nested models that fired are joined with
    /// `", "` in declaration order. An error from `terminate_search` is treated as
    /// "did not terminate".
    pub fn explain_termination(
        &self,
        start_time: &Instant,
        solution_size: usize,
        iterations: u64,
    ) -> Option<String> {
        use TerminationModel as T;
        // Lazy: `Combined` never needs its own aggregate result because each nested
        // model re-checks itself when asked for an explanation.
        let exceeded = || {
            self.terminate_search(start_time, solution_size, iterations)
                .unwrap_or(false)
        };
        match self {
            T::Combined { models } => {
                let joined = models
                    .iter()
                    .filter_map(|m| m.explain_termination(start_time, solution_size, iterations))
                    .collect::<Vec<_>>()
                    .join(", ");
                if joined.is_empty() {
                    None
                } else {
                    Some(joined)
                }
            }
            T::QueryRuntimeLimit { limit, .. } => {
                exceeded().then(|| format!("exceeded runtime limit of {}", limit.hhmmss()))
            }
            T::SolutionSizeLimit { limit } => {
                exceeded().then(|| format!("exceeded solution size limit of {}", limit))
            }
            T::IterationsLimit { limit } => {
                exceeded().then(|| format!("exceeded iteration limit of {}", limit))
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::TerminationModel as T;
    use std::time::{Duration, Instant};

    /// Runtime limit shared by the runtime and combined tests.
    const RUNTIME_LIMIT: Duration = Duration::from_secs(2);
    /// Iteration limit used by the combined tests.
    const ITERATION_LIMIT: u64 = 5;
    /// Solution size limit used by the combined tests.
    const SOLUTION_LIMIT: usize = 3;

    /// Returns a start time that lies `elapsed` in the past.
    fn started_ago(elapsed: Duration) -> Instant {
        Instant::now() - elapsed
    }

    /// Runtime limit (checked every iteration), iteration limit and solution size
    /// limit combined, in that order.
    fn combined_fixture() -> T {
        T::Combined {
            models: vec![
                T::QueryRuntimeLimit {
                    limit: RUNTIME_LIMIT,
                    frequency: 1,
                },
                T::IterationsLimit {
                    limit: ITERATION_LIMIT,
                },
                T::SolutionSizeLimit {
                    limit: SOLUTION_LIMIT,
                },
            ],
        }
    }

    #[test]
    fn test_within_runtime_limit() {
        let start_time = started_ago(Duration::from_secs(1));
        let frequency = 10;
        let m = T::QueryRuntimeLimit {
            limit: RUNTIME_LIMIT,
            frequency,
        };
        for iteration in 0..=frequency {
            let result = m.terminate_search(&start_time, 0, iteration).unwrap();
            assert!(!result, "iteration {} should not terminate", iteration);
        }
    }

    #[test]
    fn test_exceeds_runtime_limit() {
        let start_time = started_ago(Duration::from_secs(3));
        let frequency = 10;
        let m = T::QueryRuntimeLimit {
            limit: RUNTIME_LIMIT,
            frequency,
        };
        for iteration in 0..=frequency {
            let result = m.terminate_search(&start_time, 0, iteration).unwrap();
            // The clock is only sampled on multiples of `frequency`.
            let expected = iteration % frequency == 0;
            assert_eq!(result, expected, "iteration {}", iteration);
        }
    }

    #[test]
    fn test_iterations_limit() {
        let m = T::IterationsLimit { limit: 5 };
        let i = Instant::now();
        assert!(!m.terminate_search(&i, 4, 4).unwrap());
        assert!(m.terminate_search(&i, 5, 5).unwrap());
        assert!(m.terminate_search(&i, 6, 6).unwrap());
    }

    #[test]
    fn test_size_limit() {
        let m = T::SolutionSizeLimit { limit: 5 };
        let i = Instant::now();
        assert!(!m.terminate_search(&i, 4, 4).unwrap());
        // The size limit is strict: reaching it does not terminate.
        assert!(!m.terminate_search(&i, 5, 5).unwrap());
        assert!(m.terminate_search(&i, 6, 6).unwrap());
    }

    #[test]
    fn test_combined_3() {
        let start_time = started_ago(Duration::from_secs(3));
        let cm = combined_fixture();
        let (size, iteration) = (SOLUTION_LIMIT + 1, ITERATION_LIMIT + 1);

        assert!(cm.terminate_search(&start_time, size, iteration).unwrap());

        let msg = cm.explain_termination(&start_time, size, iteration);
        let expected = [
            "exceeded runtime limit of 0:00:02.000",
            "exceeded iteration limit of 5",
            "exceeded solution size limit of 3",
        ]
        .join(", ");
        assert_eq!(msg, Some(expected));
    }

    #[test]
    fn test_combined_2_of_3() {
        let start_time = started_ago(Duration::from_secs(3));
        let cm = combined_fixture();
        let (size, iteration) = (SOLUTION_LIMIT - 1, ITERATION_LIMIT + 1);

        assert!(cm.terminate_search(&start_time, size, iteration).unwrap());

        let msg = cm.explain_termination(&start_time, size, iteration);
        let expected = [
            "exceeded runtime limit of 0:00:02.000",
            "exceeded iteration limit of 5",
        ]
        .join(", ");
        assert_eq!(msg, Some(expected));
    }
}