use std::ops::RangeInclusive;
use crate::capture::Capture;
use crate::comb::{count_permutations, pick};
/// Parameters for [`univariate_descent`].
#[derive(Clone, Debug, PartialEq)]
pub struct UnivariateDescentConfig {
    /// Starting point of the search; its loss is evaluated before any step is taken.
    pub init_value: f64,
    /// First step added to the current value; its sign sets the initial search direction.
    pub init_step: f64,
    /// The search stops once a reversed-and-halved step's magnitude falls below this value.
    pub min_step: f64,
    /// Upper limit on the number of steps (loss evaluations after the initial one).
    pub max_steps: u64,
    /// The search stops as soon as a residual at or below this value is found.
    pub acceptable_residual: f64,
}
/// Result of [`univariate_descent`].
#[derive(Clone, Debug, PartialEq)]
pub struct UnivariateDescentOutcome {
    /// Number of steps taken; zero when the initial value was already acceptable.
    pub steps: u64,
    /// The value with the lowest residual seen during the search.
    pub optimal_value: f64,
    /// The residual at `optimal_value`.
    pub optimal_residual: f64,
}
/// Minimises a one-dimensional `loss_f` by walking from `config.init_value` in steps of
/// `config.init_step`, turning around and halving the step whenever the loss gets worse.
///
/// The walk ends when a residual at or below `config.acceptable_residual` is found, when a
/// reversed step's magnitude drops below `config.min_step`, or after `config.max_steps` steps,
/// whichever comes first. The outcome reports the best value seen and how many steps were taken
/// (zero if the starting value was already acceptable).
pub fn univariate_descent(
    config: &UnivariateDescentConfig,
    mut loss_f: impl FnMut(f64) -> f64,
) -> UnivariateDescentOutcome {
    let initial_residual = loss_f(config.init_value);
    let mut best = UnivariateDescentOutcome {
        steps: 0,
        optimal_value: config.init_value,
        optimal_residual: initial_residual,
    };
    if initial_residual <= config.acceptable_residual {
        // Good enough already: report without taking a single step.
        return best;
    }

    let mut position = config.init_value;
    let mut position_residual = initial_residual;
    let mut stride = config.init_step;
    for _ in 0..config.max_steps {
        best.steps += 1;
        let candidate = position + stride;
        let candidate_residual = loss_f(candidate);

        if candidate_residual > position_residual {
            // Overshot: turn around and take smaller steps.
            stride = -stride * 0.5;
            if stride.abs() < config.min_step {
                break;
            }
        } else if candidate_residual < best.optimal_residual {
            best.optimal_value = candidate;
            best.optimal_residual = candidate_residual;
            if candidate_residual <= config.acceptable_residual {
                break;
            }
        }

        // The walk always moves to the candidate, even after an overshoot.
        position = candidate;
        position_residual = candidate_residual;
    }
    best
}
/// Parameters for [`hypergrid_search`].
#[derive(Debug, Clone)]
pub struct HypergridSearchConfig<'a> {
    /// Upper limit on the number of search rounds; each round sweeps the whole grid once.
    pub max_steps: u64,
    /// The search stops as soon as a residual at or below this value is found.
    pub acceptable_residual: f64,
    /// Hard per-dimension limits of the search space; the narrowed grid is clamped to them.
    pub bounds: Capture<'a, Vec<RangeInclusive<f64>>, [RangeInclusive<f64>]>,
    /// Number of grid points sampled along each dimension (endpoints included). Values below 2
    /// are not meaningful: the grid spacing divides by `resolution - 1`.
    pub resolution: usize,
}
/// Result of [`hypergrid_search`].
#[derive(Clone, Debug, PartialEq)]
pub struct HypergridSearchOutcome {
    /// Number of grid sweeps performed, including one cut short by reaching the acceptable
    /// residual.
    pub steps: u64,
    /// Best point found, one coordinate per dimension; all zeros if no evaluated point improved
    /// on `f64::MAX`.
    pub optimal_values: Vec<f64>,
    /// Loss at `optimal_values`, or `f64::MAX` if no point was recorded.
    pub optimal_residual: f64,
}
/// Searches a bounded, multi-dimensional space for the point that minimises `loss_f`, considering
/// only points for which `constraint_f` returns `true`.
///
/// Each step lays a regular grid of `config.resolution` points per dimension (endpoints included)
/// over the current bounds and evaluates `loss_f` once at every grid point that satisfies
/// `constraint_f`. The bounds are then shrunk by a factor of `config.resolution` around the best
/// point found so far, clamped to the hard limits in `config.bounds`, and the next step searches
/// the finer grid.
///
/// The search ends after `config.max_steps` steps, or as soon as a residual at or below
/// `config.acceptable_residual` is found. If no evaluated point improves on `f64::MAX` (e.g. none
/// satisfies `constraint_f`), the outcome carries all-zero values and a residual of `f64::MAX`.
///
/// # Panics
///
/// Panics if `config.resolution` is less than 2, since the grid spacing divides by
/// `resolution - 1`.
pub fn hypergrid_search(
    config: &HypergridSearchConfig,
    mut constraint_f: impl FnMut(&[f64]) -> bool,
    mut loss_f: impl FnMut(&[f64]) -> f64,
) -> HypergridSearchOutcome {
    assert!(
        config.resolution >= 2,
        "resolution must be at least 2, got {}",
        config.resolution
    );

    let dimensions = config.bounds.len();
    // `vec!` yields exactly `dimensions` elements; `Vec::with_capacity` only guarantees a lower
    // bound on capacity, so sizing a vector from `capacity()` is not reliable.
    let mut values = vec![0.0; dimensions];
    let mut optimal_values = vec![0.0; dimensions];
    let mut optimal_residual = f64::MAX;
    let cardinalities = vec![config.resolution; dimensions];
    let mut ordinals = vec![0_usize; dimensions];
    let permutations = count_permutations(&cardinalities);
    // Working bounds, narrowed every step; `config.bounds` remains the hard limit.
    let mut bounds = (*config.bounds).to_vec();
    let inv_resolution = 1.0 / (config.resolution - 1) as f64;

    let mut steps = 0;
    'outer: while steps < config.max_steps {
        steps += 1;
        for permutation in 0..permutations {
            pick(&cardinalities, permutation, &mut ordinals);
            // Fill in every coordinate of this grid point before evaluating it, so the
            // constraint and loss never see a half-updated point.
            for ((value, &ordinal), bound) in values.iter_mut().zip(&ordinals).zip(&bounds) {
                let range = bound.end() - bound.start();
                *value = bound.start() + ordinal as f64 * range * inv_resolution;
            }
            if constraint_f(&values) {
                let residual = loss_f(&values);
                if residual < optimal_residual {
                    optimal_residual = residual;
                    optimal_values.copy_from_slice(&values);
                    if residual <= config.acceptable_residual {
                        break 'outer;
                    }
                }
            }
        }

        // Shrink each dimension's window by a factor of `resolution`, centred on the optimum
        // but never extending past the hard bounds.
        // NOTE(review): if nothing has been recorded yet, this centres on the all-zero default.
        for (dimension, &value) in optimal_values.iter().enumerate() {
            let hard_bound = &config.bounds[dimension];
            let bound = &mut bounds[dimension];
            let new_range = (bound.end() - bound.start()) / config.resolution as f64;
            let new_start = f64::max(*hard_bound.start(), value - new_range / 2.0);
            let new_end = f64::min(new_start + new_range, *hard_bound.end());
            *bound = new_start..=new_end;
        }
    }

    HypergridSearchOutcome {
        steps,
        optimal_values,
        optimal_residual,
    }
}
#[cfg(test)]
mod tests;