use anyhow::ensure;
use predicates::{Predicate, reflection::PredicateReflection};
use std::fmt::Display;
/// Snapshot of the quantities read by the stopping predicates in this file
/// (`MaxUpdates`, `MinGain`, `MinAvgImprov`, `MinModified`, `PercModified`).
///
/// NOTE(review): the field descriptions below are inferred from how the
/// predicates in this file read them; confirm against the code that fills
/// this struct in.
#[doc(hidden)]
#[derive(Debug)]
pub struct PredParams {
/// Number of nodes (presumably of the graph being processed); `MinModified`
/// and `PercModified` derive their bound on `modified` from it.
pub num_nodes: usize,
/// Presumably the number of arcs of the graph; not read by any predicate in
/// this file.
pub num_arcs: u64,
/// Gain of the current update; `MinGain` fires when it is at most its
/// threshold.
pub gain: f64,
/// Average gain improvement; `MinAvgImprov` fires when it is at most its
/// threshold.
pub avg_gain_impr: f64,
/// Number of modified items (presumably nodes) in the current update;
/// compared against `√num_nodes` by `MinModified` and against a fraction of
/// `num_nodes` by `PercModified`.
pub modified: usize,
/// Index of the current update; `MaxUpdates` compares `update + 1` with its
/// maximum, which suggests the index is zero-based.
pub update: usize,
}
/// Stopping predicate that fires once a maximum number of updates is reached.
///
/// Build it from a `usize`, from an `Option<usize>` (where `None` selects
/// [`MaxUpdates::DEFAULT_MAX_UPDATES`]), or through [`Default`].
#[derive(Debug, Clone)]
pub struct MaxUpdates {
    max_updates: usize,
}

impl MaxUpdates {
    /// Maximum used when none is given: `usize::MAX`, i.e. practically no limit.
    pub const DEFAULT_MAX_UPDATES: usize = usize::MAX;
}

impl From<Option<usize>> for MaxUpdates {
    /// Uses the given maximum, falling back to
    /// [`MaxUpdates::DEFAULT_MAX_UPDATES`] for `None`.
    fn from(max_updates: Option<usize>) -> Self {
        let limit = max_updates.unwrap_or(Self::DEFAULT_MAX_UPDATES);
        Self { max_updates: limit }
    }
}

impl From<usize> for MaxUpdates {
    /// Uses the given maximum as-is.
    fn from(max_updates: usize) -> Self {
        Self { max_updates }
    }
}

impl Default for MaxUpdates {
    /// Returns a predicate with maximum [`MaxUpdates::DEFAULT_MAX_UPDATES`].
    fn default() -> Self {
        Self {
            max_updates: Self::DEFAULT_MAX_UPDATES,
        }
    }
}

impl Display for MaxUpdates {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "(max updates: {})", self.max_updates)
    }
}
impl PredicateReflection for MaxUpdates {}

impl Predicate<PredParams> for MaxUpdates {
    /// Returns `true` once `update + 1` reaches the configured maximum, i.e.
    /// after `max_updates` updates if `update` is a zero-based index.
    ///
    /// A maximum of `0` or `1` makes the predicate fire on the first update.
    fn eval(&self, pred_params: &PredParams) -> bool {
        // `saturating_add` avoids an overflow panic in debug builds (and a
        // wrap to 0 in release builds) when `update == usize::MAX`; in that
        // case `usize::MAX >= max_updates` always holds, which is the
        // correct answer.
        pred_params.update.saturating_add(1) >= self.max_updates
    }
}
/// Stopping predicate that fires once the gain drops to or below a fixed
/// threshold.
///
/// The threshold must be a nonnegative, non-NaN number. Build it with
/// [`TryFrom`] from an `f64` or an `Option<f64>` (where `None` selects
/// [`MinGain::DEFAULT_THRESHOLD`]), or through [`Default`].
#[derive(Debug, Clone)]
pub struct MinGain {
    threshold: f64,
}

impl MinGain {
    /// Threshold used by [`Default`] and when converting from `None`.
    pub const DEFAULT_THRESHOLD: f64 = 0.001;
}

impl TryFrom<Option<f64>> for MinGain {
    type Error = anyhow::Error;

    /// Builds the predicate from an optional threshold; `None` yields the
    /// default predicate.
    ///
    /// # Errors
    ///
    /// Fails if the threshold is NaN or negative.
    fn try_from(threshold: Option<f64>) -> anyhow::Result<Self> {
        Ok(match threshold {
            Some(threshold) => {
                ensure!(!threshold.is_nan(), "The threshold must not be NaN");
                ensure!(threshold >= 0.0, "The threshold must be nonnegative");
                MinGain { threshold }
            }
            None => Self::default(),
        })
    }
}

impl TryFrom<f64> for MinGain {
    type Error = anyhow::Error;

    /// Builds the predicate from an explicit threshold.
    ///
    /// # Errors
    ///
    /// Fails if the threshold is NaN or negative.
    fn try_from(threshold: f64) -> anyhow::Result<Self> {
        Some(threshold).try_into()
    }
}

impl Default for MinGain {
    /// Returns a predicate with threshold [`MinGain::DEFAULT_THRESHOLD`].
    fn default() -> Self {
        // The constant is validated through the same path as user input, so
        // an invalid edit to it fails loudly here.
        Self::try_from(Self::DEFAULT_THRESHOLD)
            .expect("MinGain::DEFAULT_THRESHOLD must be a valid threshold")
    }
}

impl Display for MinGain {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!("(min gain: {})", self.threshold))
    }
}
impl PredicateReflection for MinGain {}
impl Predicate<PredParams> for MinGain {
fn eval(&self, pred_params: &PredParams) -> bool {
pred_params.gain <= self.threshold
}
}
/// Stopping predicate that fires once the average gain improvement drops to
/// or below a fixed threshold.
///
/// Any non-NaN threshold is accepted, negative values included. Build it
/// with [`TryFrom`] from an `f64` or an `Option<f64>` (where `None` selects
/// [`MinAvgImprov::DEFAULT_THRESHOLD`]), or through [`Default`].
#[derive(Debug, Clone)]
pub struct MinAvgImprov {
    threshold: f64,
}

impl MinAvgImprov {
    /// Threshold used by [`Default`] and when converting from `None`.
    pub const DEFAULT_THRESHOLD: f64 = 0.1;
}

impl TryFrom<Option<f64>> for MinAvgImprov {
    type Error = anyhow::Error;

    /// Builds the predicate from an optional threshold; `None` yields the
    /// default predicate.
    ///
    /// # Errors
    ///
    /// Fails if the threshold is NaN.
    fn try_from(threshold: Option<f64>) -> anyhow::Result<Self> {
        Ok(match threshold {
            Some(threshold) => {
                ensure!(!threshold.is_nan(), "The threshold must not be NaN");
                MinAvgImprov { threshold }
            }
            None => Self::default(),
        })
    }
}

impl TryFrom<f64> for MinAvgImprov {
    type Error = anyhow::Error;

    /// Builds the predicate from an explicit threshold.
    ///
    /// # Errors
    ///
    /// Fails if the threshold is NaN.
    fn try_from(threshold: f64) -> anyhow::Result<Self> {
        Some(threshold).try_into()
    }
}

impl Default for MinAvgImprov {
    /// Returns a predicate with threshold [`MinAvgImprov::DEFAULT_THRESHOLD`].
    fn default() -> Self {
        // The constant is validated through the same path as user input, so
        // an invalid edit to it fails loudly here.
        Self::try_from(Self::DEFAULT_THRESHOLD)
            .expect("MinAvgImprov::DEFAULT_THRESHOLD must be a valid threshold")
    }
}

impl Display for MinAvgImprov {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!(
            "(min avg gain improvement: {})",
            self.threshold
        ))
    }
}
impl PredicateReflection for MinAvgImprov {}
impl Predicate<PredParams> for MinAvgImprov {
fn eval(&self, pred_params: &PredParams) -> bool {
pred_params.avg_gain_impr <= self.threshold
}
}
/// Stopping predicate that fires once the number of modified items is at
/// most the square root of the number of nodes. It has no parameters.
#[derive(Debug, Clone, Default)]
pub struct MinModified {}

impl Display for MinModified {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "(min modified: √n)")
    }
}
impl PredicateReflection for MinModified {}

impl Predicate<PredParams> for MinModified {
    /// Fires when `modified <= √num_nodes` (compared as `f64`).
    fn eval(&self, pred_params: &PredParams) -> bool {
        let limit = f64::sqrt(pred_params.num_nodes as f64);
        let modified = pred_params.modified as f64;
        modified <= limit
    }
}
/// Stopping predicate that fires once the number of modified items is at
/// most a given percentage of the number of nodes.
///
/// Build it with [`TryFrom<f64>`], passing a percentage in `[0, 100]`. The
/// [`Default`] value has a 0% threshold, so it fires only when nothing was
/// modified.
#[derive(Debug, Clone, Default)]
pub struct PercModified {
    /// Fraction of nodes (percentage / 100) compared against `modified`.
    threshold: f64,
    /// The percentage exactly as supplied, kept for display: recomputing it
    /// as `threshold * 100.0` introduces rounding noise (e.g. 7% would print
    /// as `7.000000000000001%`).
    percent: f64,
}

impl Display for PercModified {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!("(min modified: {}%)", self.percent))
    }
}

impl TryFrom<f64> for PercModified {
    type Error = anyhow::Error;

    /// Builds the predicate from a percentage.
    ///
    /// # Errors
    ///
    /// Fails if the percentage is NaN, negative, or greater than 100.
    fn try_from(threshold: f64) -> anyhow::Result<Self> {
        // Checked explicitly so NaN gets an accurate message instead of
        // failing the nonnegativity check below.
        ensure!(
            !threshold.is_nan(),
            "The percent threshold must not be NaN"
        );
        ensure!(
            threshold >= 0.0,
            "The percent threshold must be nonnegative"
        );
        ensure!(
            threshold <= 100.0,
            "The percent threshold must be at most 100"
        );
        Ok(PercModified {
            threshold: threshold / 100.0,
            percent: threshold,
        })
    }
}
impl PredicateReflection for PercModified {}
impl Predicate<PredParams> for PercModified {
fn eval(&self, pred_params: &PredParams) -> bool {
(pred_params.modified as f64) <= (pred_params.num_nodes as f64) * self.threshold
}
}