use crate::domain::{Distribution, Domain, Range, VariableBuilder};
use crate::registry::FactoryRegistry;
use crate::rng::ArcRng;
use crate::solver::{Capabilities, Capability};
use crate::trial::{Params, Values};
use crate::{ErrorKind, Result};
use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
use std::fmt;
use structopt::StructOpt;
/// Builder for [`ProblemSpec`].
///
/// Created with [`ProblemSpecBuilder::new`]; the collected parameters, values and steps
/// are only validated when [`ProblemSpecBuilder::finish`] is called.
#[derive(Debug)]
pub struct ProblemSpecBuilder {
/// Problem name.
name: String,
/// Free-form key/value attributes attached to the problem.
attrs: BTreeMap<String, String>,
/// Builders for the parameter variables (turned into `ProblemSpec::params_domain`).
params: Vec<VariableBuilder>,
/// Builders for the value variables (turned into `ProblemSpec::values_domain`).
values: Vec<VariableBuilder>,
/// Evaluable steps; validated by `EvaluableSteps::new` in `finish`.
steps: Vec<u64>,
}
impl ProblemSpecBuilder {
    /// Starts building the specification of a problem called `problem_name`.
    ///
    /// Attributes, parameters and values start out empty, and the steps default to the
    /// single step `1`.
    pub fn new(problem_name: &str) -> Self {
        Self {
            name: String::from(problem_name),
            attrs: BTreeMap::default(),
            params: Vec::default(),
            values: Vec::default(),
            steps: vec![1],
        }
    }

    /// Sets attribute `key` to `value`, overwriting any earlier value for `key`.
    pub fn attr(mut self, key: &str, value: &str) -> Self {
        self.attrs.insert(String::from(key), String::from(value));
        self
    }

    /// Appends one parameter variable.
    pub fn param(mut self, var: VariableBuilder) -> Self {
        self.params.push(var);
        self
    }

    /// Replaces every parameter variable added so far with `vars`.
    pub fn params(self, vars: Vec<VariableBuilder>) -> Self {
        Self {
            params: vars,
            ..self
        }
    }

    /// Appends one value variable.
    pub fn value(mut self, var: VariableBuilder) -> Self {
        self.values.push(var);
        self
    }

    /// Replaces the evaluable steps. They are validated later by
    /// [`ProblemSpecBuilder::finish`].
    pub fn steps<I>(self, steps: I) -> Self
    where
        I: IntoIterator<Item = u64>,
    {
        Self {
            steps: steps.into_iter().collect(),
            ..self
        }
    }

    /// Validates the collected data and produces the [`ProblemSpec`].
    ///
    /// # Errors
    ///
    /// Propagates any error from `Domain::new` (for the parameters, then the values) and
    /// from [`EvaluableSteps::new`] (for the steps).
    pub fn finish(self) -> Result<ProblemSpec> {
        let Self {
            name,
            attrs,
            params,
            values,
            steps,
        } = self;
        let params_domain = track!(Domain::new(params))?;
        let values_domain = track!(Domain::new(values))?;
        let steps = track!(EvaluableSteps::new(steps))?;
        Ok(ProblemSpec {
            name,
            attrs,
            params_domain,
            values_domain,
            steps,
        })
    }
}
/// Specification of a problem: its name, attributes, parameter and value domains, and
/// the steps at which it can be evaluated.
///
/// Can be constructed with [`ProblemSpecBuilder`].
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct ProblemSpec {
/// Problem name.
pub name: String,
/// Free-form attributes; a missing field deserializes as an empty map.
#[serde(default)]
pub attrs: BTreeMap<String, String>,
/// Domain of the parameters.
pub params_domain: Domain,
/// Domain of the values; more than one variable makes the problem multi-objective
/// (see [`ProblemSpec::requirements`]).
pub values_domain: Domain,
/// Steps at which the problem can be evaluated.
pub steps: EvaluableSteps,
}
impl ProblemSpec {
    /// Returns the solver capabilities needed to handle this problem.
    ///
    /// - `MultiObjective` when there is more than one value variable.
    /// - `Conditional` when any parameter has a constraint.
    /// - For every parameter, the capability matching its range kind and distribution;
    ///   categorical parameters need `Categorical` whatever their distribution.
    pub fn requirements(&self) -> Capabilities {
        let mut required = Capabilities::empty();

        let objective_count = self.values_domain.variables().len();
        if objective_count > 1 {
            required.add_capability(Capability::MultiObjective);
        }

        for param in self.params_domain.variables() {
            if param.constraint().is_some() {
                required.add_capability(Capability::Conditional);
            }
            let sampling = match (param.range(), param.distribution()) {
                (Range::Continuous { .. }, Distribution::Uniform) => Capability::UniformContinuous,
                (Range::Continuous { .. }, Distribution::LogUniform) => {
                    Capability::LogUniformContinuous
                }
                (Range::Discrete { .. }, Distribution::Uniform) => Capability::UniformDiscrete,
                (Range::Discrete { .. }, Distribution::LogUniform) => {
                    Capability::LogUniformDiscrete
                }
                (Range::Categorical { .. }, _) => Capability::Categorical,
            };
            required.add_capability(sampling);
        }

        required
    }
}
/// A description of how to build a problem that can be parsed from the command line
/// (`StructOpt`) and (de)serialized (`Serialize`/`Deserialize`).
///
/// A recipe is turned into a [`ProblemFactory`] by [`ProblemRecipe::create_factory`].
pub trait ProblemRecipe: Clone + Send + StructOpt + Serialize + for<'a> Deserialize<'a> {
/// Factory type produced by this recipe.
type Factory: ProblemFactory;
/// Creates the factory described by this recipe.
///
/// NOTE(review): `registry` is presumably used to build components the recipe refers
/// to — confirm against `FactoryRegistry`.
fn create_factory(&self, registry: &FactoryRegistry) -> Result<Self::Factory>;
}
/// Reports a problem's specification and creates instances of that problem.
pub trait ProblemFactory: Send {
/// Problem type created by this factory.
type Problem: Problem;
/// Returns the specification of the problems created by this factory.
fn specification(&self) -> Result<ProblemSpec>;
/// Creates a new problem instance, handing it the random number generator `rng`.
fn create_problem(&self, rng: ArcRng) -> Result<Self::Problem>;
}
/// Request passed to the closure inside [`BoxProblemFactory`]; one variant per
/// [`ProblemFactory`] method.
enum ProblemFactoryCall {
/// Corresponds to `ProblemFactory::specification`.
Specification,
/// Corresponds to `ProblemFactory::create_problem`, carrying its `rng` argument.
CreateProblem(ArcRng),
}
/// Response produced by the closure inside [`BoxProblemFactory`]; each variant answers
/// the `ProblemFactoryCall` variant of the same name.
enum ProblemFactoryReturn {
/// Result of `ProblemFactory::specification`.
Specification(ProblemSpec),
/// Result of `ProblemFactory::create_problem`, already type-erased.
CreateProblem(BoxProblem),
}
/// Type-erased [`ProblemFactory`].
///
/// Holds a single boxed closure that receives a `ProblemFactoryCall` and answers with
/// the matching `ProblemFactoryReturn`.
pub struct BoxProblemFactory(
Box<dyn Fn(ProblemFactoryCall) -> Result<ProblemFactoryReturn> + Send>,
);
impl BoxProblemFactory {
pub fn new<T>(problem: T) -> Self
where
T: 'static + ProblemFactory,
{
Self(Box::new(move |call| match call {
ProblemFactoryCall::Specification => problem
.specification()
.map(ProblemFactoryReturn::Specification),
ProblemFactoryCall::CreateProblem(rng) => problem
.create_problem(rng)
.map(BoxProblem::new)
.map(ProblemFactoryReturn::CreateProblem),
}))
}
}
impl ProblemFactory for BoxProblemFactory {
    type Problem = BoxProblem;

    /// Asks the wrapped factory for its specification.
    fn specification(&self) -> Result<ProblemSpec> {
        let reply = track!((self.0)(ProblemFactoryCall::Specification))?;
        match reply {
            ProblemFactoryReturn::Specification(spec) => Ok(spec),
            // The closure built by `BoxProblemFactory::new` always answers in kind.
            ProblemFactoryReturn::CreateProblem(_) => unreachable!(),
        }
    }

    /// Asks the wrapped factory to create a (boxed) problem.
    fn create_problem(&self, rng: ArcRng) -> Result<Self::Problem> {
        let reply = track!((self.0)(ProblemFactoryCall::CreateProblem(rng)))?;
        match reply {
            ProblemFactoryReturn::CreateProblem(created) => Ok(created),
            // The closure built by `BoxProblemFactory::new` always answers in kind.
            ProblemFactoryReturn::Specification(_) => unreachable!(),
        }
    }
}
impl fmt::Debug for BoxProblemFactory {
    /// The wrapped closure is opaque, so only the type name is shown.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("BoxProblemFactory { .. }")
    }
}
/// A problem instance that can create evaluators for concrete parameter sets.
pub trait Problem: Send {
/// Evaluator type created by this problem.
type Evaluator: Evaluator;
/// Creates an evaluator for the given parameters.
fn create_evaluator(&self, params: Params) -> Result<Self::Evaluator>;
}
/// Type-erased [`Problem`]: a boxed closure that turns parameters into a boxed evaluator.
pub struct BoxProblem(Box<dyn Fn(Params) -> Result<BoxEvaluator> + Send>);
impl BoxProblem {
pub fn new<T>(problem: T) -> Self
where
T: 'static + Problem,
{
Self(Box::new(move |params| {
problem.create_evaluator(params).map(BoxEvaluator::new)
}))
}
}
impl Problem for BoxProblem {
    type Evaluator = BoxEvaluator;

    /// Delegates to the wrapped problem's evaluator constructor.
    fn create_evaluator(&self, params: Params) -> Result<Self::Evaluator> {
        let create = &self.0;
        track!(create(params))
    }
}
impl fmt::Debug for BoxProblem {
    /// The wrapped closure is opaque, so only the type name is shown.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("BoxProblem { .. }")
    }
}
/// Evaluates one parameter set; evaluation can be advanced up to a requested step.
pub trait Evaluator: Send {
/// Evaluates up to `next_step` and returns a step number together with the values.
///
/// NOTE(review): the returned `u64` is presumably the step actually reached, which may
/// differ from `next_step` — confirm with implementors.
fn evaluate(&mut self, next_step: u64) -> Result<(u64, Values)>;
}
impl<T: Evaluator + ?Sized> Evaluator for Box<T> {
    /// Forwards the call to the boxed evaluator.
    fn evaluate(&mut self, next_step: u64) -> Result<(u64, Values)> {
        let inner: &mut T = &mut **self;
        inner.evaluate(next_step)
    }
}
/// Type-erased [`Evaluator`] (a boxed trait object).
pub struct BoxEvaluator(Box<dyn Evaluator>);
impl BoxEvaluator {
    /// Boxes any concrete [`Evaluator`].
    pub fn new<T>(evaluator: T) -> Self
    where
        T: 'static + Evaluator,
    {
        let boxed: Box<dyn Evaluator> = Box::new(evaluator);
        Self(boxed)
    }
}
impl Evaluator for BoxEvaluator {
    /// Forwards the call to the wrapped evaluator.
    fn evaluate(&mut self, next_step: u64) -> Result<(u64, Values)> {
        let inner: &mut dyn Evaluator = &mut *self.0;
        inner.evaluate(next_step)
    }
}
impl fmt::Debug for BoxEvaluator {
    /// The wrapped evaluator is opaque, so only the type name is shown.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("BoxEvaluator { .. }")
    }
}
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct EvaluableSteps(EvaluableStepsInner);
impl EvaluableSteps {
pub fn new(steps: Vec<u64>) -> Result<Self> {
track!(EvaluableStepsInner::new(steps)).map(Self)
}
pub fn last(&self) -> u64 {
self.0.last()
}
pub fn iter<'a>(&'a self) -> impl 'a + Iterator<Item = u64> {
self.0.iter()
}
}
/// Internal representation of [`EvaluableSteps`].
///
/// Because of `#[serde(untagged)]`, `Max` is (de)serialized as a bare integer and `Steps`
/// as an array of integers.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(untagged)]
enum EvaluableStepsInner {
/// Every step from `1` through `n` inclusive.
Max(u64),
/// An explicit step list; `EvaluableStepsInner::new` only produces this variant when the
/// list is not simply `1..=n`.
Steps(Vec<u64>),
}
impl EvaluableStepsInner {
fn new(steps: Vec<u64>) -> Result<Self> {
track_assert!(!steps.is_empty(), ErrorKind::InvalidInput);
track_assert!(steps[0] > 0, ErrorKind::InvalidInput);
for (a, b) in steps.iter().zip(steps.iter().skip(1)) {
track_assert!(a < b, ErrorKind::InvalidInput);
}
let last = steps[steps.len() - 1];
if last == steps.len() as u64 {
Ok(Self::Max(last))
} else {
Ok(Self::Steps(steps))
}
}
fn last(&self) -> u64 {
match self {
Self::Max(n) => *n,
Self::Steps(ns) => ns[ns.len() - 1],
}
}
fn iter<'a>(&'a self) -> impl 'a + Iterator<Item = u64> {
match self {
Self::Max(n) => itertools::Either::Left(1..=*n),
Self::Steps(ns) => itertools::Either::Right(ns.iter().copied()),
}
}
}