1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
use crate::filter::KurobakoFilterRecipe;
use kurobako_core::epi;
use kurobako_core::filter::{BoxFilter, Filter as _, FilterRecipe as _};
use kurobako_core::json;
use kurobako_core::problem::ProblemSpec;
use kurobako_core::solver::{
    BoxSolver, BoxSolverRecipe, ObservedObs, Solver, SolverRecipe, SolverSpec, UnobservedObs,
};
use kurobako_core::{Error, Result};
use kurobako_solvers::{asha, optuna, random};
use rand::{self, Rng};
use serde::{Deserialize, Serialize};
use serde_json;
use structopt::StructOpt;
use yamakan::observation::IdGen;

// Top-level solver recipe: an optional tag, an optional list of filter
// recipes, and the concrete solver recipe (`InnerRecipe`) they wrap.
//
// Plain `//` comments are used deliberately: structopt turns `///` doc
// comments on this struct and its fields into CLI help text, so adding them
// would change `--help` output.
#[derive(Debug, Clone, StructOpt, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
#[structopt(rename_all = "kebab-case")]
pub struct KurobakoSolverRecipe {
    // Optional label; the built solver appends it to the inner solver's name
    // as `#<tag>`. Omitted from serialized output when `None`.
    #[structopt(long)]
    #[serde(default, skip_serializing_if = "Option::is_none")]
    tag: Option<String>,

    // Filter recipes; each CLI value is parsed as JSON via `json::parse_json`.
    // Omitted from serialized output when empty.
    #[structopt(long)]
    #[structopt(parse(try_from_str = "json::parse_json"))]
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    filters: Vec<KurobakoFilterRecipe>,

    // Never serialized and not read by any code in this module.
    // NOTE(review): presumably a CLI marker that terminates the variadic
    // `--filters` list so the following arguments are parsed as the inner
    // recipe — confirm.
    #[structopt(long)]
    #[serde(default, skip_serializing)]
    filters_end: bool,

    // The concrete solver recipe, flattened into this recipe for both
    // structopt and serde.
    #[structopt(flatten)]
    #[serde(flatten)]
    inner: InnerRecipe,
}
impl SolverRecipe for KurobakoSolverRecipe {
    type Solver = KurobakoSolver;

    fn create_solver(&self, mut problem: ProblemSpec) -> Result<Self::Solver> {
        let mut filters = self
            .filters
            .iter()
            .map(|r| track!(r.create_filter()))
            .collect::<Result<Vec<_>>>()?;
        for f in &mut filters {
            track!(f.filter_problem_spec(&mut problem))?;
        }

        let inner = track!(self.inner.create_solver(problem))?;
        Ok(KurobakoSolver {
            tag: self.tag.clone(),
            inner,
            filters,
        })
    }
}

// The concrete solver kinds that `KurobakoSolverRecipe` can wrap. Each
// variant holds that solver's own recipe. Variant names are kebab-cased for
// both serde and structopt (`rename_all`).
//
// `//` rather than `///` is used so structopt-generated help text is unchanged.
#[derive(Debug, Clone, StructOpt, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
#[structopt(rename_all = "kebab-case")]
enum InnerRecipe {
    // Recipe for `random::RandomSolverRecipe`.
    Random(random::RandomSolverRecipe),
    // Recipe for `optuna::OptunaSolverRecipe`.
    Optuna(optuna::OptunaSolverRecipe),
    // Recipe for `asha::AshaSolverRecipe`. Its base solver is itself decoded
    // from JSON as a `KurobakoSolverRecipe` when the solver is created.
    Asha(asha::AshaSolverRecipe),
    // Recipe for an external-program solver (`epi::solver::ExternalProgramSolverRecipe`).
    Command(epi::solver::ExternalProgramSolverRecipe),
}
impl SolverRecipe for InnerRecipe {
    type Solver = BoxSolver;

    /// Instantiates the solver selected by this recipe and erases its
    /// concrete type behind a `BoxSolver`.
    ///
    /// The ASHA variant is special-cased. A private copy of its recipe is
    /// taught to decode its JSON base-solver recipe as a
    /// `KurobakoSolverRecipe` before the ASHA solver is built.
    fn create_solver(&self, problem: ProblemSpec) -> Result<Self::Solver> {
        let boxed = match self {
            InnerRecipe::Random(recipe) => BoxSolver::new(track!(recipe.create_solver(problem))?),
            InnerRecipe::Optuna(recipe) => BoxSolver::new(track!(recipe.create_solver(problem))?),
            InnerRecipe::Command(recipe) => {
                BoxSolver::new(track!(recipe.create_solver(problem))?)
            }
            InnerRecipe::Asha(recipe) => {
                // `set_recipe` needs mutable access, so work on a copy and
                // leave `self` untouched.
                let mut asha_recipe = recipe.clone();
                track!(asha_recipe.base_solver.set_recipe(|json| {
                    let raw = json.get().clone();
                    let base: KurobakoSolverRecipe =
                        track!(serde_json::from_value(raw).map_err(Error::from))?;
                    Ok(BoxSolverRecipe::new(base))
                }))?;
                BoxSolver::new(track!(asha_recipe.create_solver(problem))?)
            }
        };
        Ok(boxed)
    }
}

/// Solver produced by `KurobakoSolverRecipe`: a type-erased inner solver
/// combined with an optional name tag and a list of filters.
#[derive(Debug)]
pub struct KurobakoSolver {
    /// Optional suffix appended to the inner solver's name as `#<tag>`.
    tag: Option<String>,
    /// Filters built from the recipe. They are applied to each observation
    /// returned by `ask` and to each observation passed to `tell`.
    filters: Vec<BoxFilter>,
    /// The wrapped solver that actually proposes and consumes observations.
    inner: BoxSolver,
}
impl Solver for KurobakoSolver {
    /// Returns the inner solver's specification. When a tag is configured,
    /// `#<tag>` is appended to the name.
    fn specification(&self) -> SolverSpec {
        let mut spec = self.inner.specification();
        if let Some(tag) = self.tag.as_ref() {
            spec.name.push_str("#");
            spec.name.push_str(tag);
        }
        spec
    }

    /// Asks the inner solver for a new observation, then passes it through
    /// every filter in declaration order.
    ///
    /// NOTE(review): `create_solver` applies filters to the problem spec
    /// outside-in (first filter first). If `filter_ask` is meant to undo
    /// those transforms, this loop may need to run in reverse. Confirm
    /// against the `Filter` contract before changing it.
    fn ask<R: Rng, G: IdGen>(&mut self, rng: &mut R, idg: &mut G) -> Result<UnobservedObs> {
        let mut proposal = track!(self.inner.ask(rng, idg))?;
        for filter in self.filters.iter_mut() {
            track!(filter.filter_ask(rng, &mut proposal))?;
        }
        Ok(proposal)
    }

    /// Passes an evaluated observation through every filter in declaration
    /// order, then reports it to the inner solver.
    fn tell(&mut self, mut obs: ObservedObs) -> Result<()> {
        // TODO: `tell` receives no RNG, so the thread-local generator is used.
        let mut thread_rng = rand::thread_rng();
        for filter in self.filters.iter_mut() {
            track!(filter.filter_tell(&mut thread_rng, &mut obs))?;
        }
        track!(self.inner.tell(obs))
    }
}