1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
use kurobako_core::num::FiniteF64;
use kurobako_core::parameter::{Distribution, ParamDomain, ParamValue};
use kurobako_core::problem::ProblemSpec;
use kurobako_core::solver::{
    ObservedObs, Solver, SolverCapabilities, SolverRecipe, SolverSpec, UnobservedObs,
};
use kurobako_core::{ErrorKind, Result};
use rand::Rng;
use serde::{Deserialize, Serialize};
use structopt::StructOpt;
use yamakan::budget::{Budget, Budgeted};
use yamakan::observation::{IdGen, Obs};

#[derive(Debug, Clone, StructOpt, Serialize, Deserialize)]
pub struct RandomSolverRecipe {}
impl SolverRecipe for RandomSolverRecipe {
    type Solver = RandomSolver;

    fn create_solver(&self, problem: ProblemSpec) -> Result<Self::Solver> {
        Ok(RandomSolver {
            params_domain: problem.params_domain,
            budget: Budget::new(problem.evaluation_expense.get()),
        })
    }
}

/// Solver that draws every parameter independently at random from its domain.
#[derive(Debug)]
pub struct RandomSolver {
    // One domain per parameter; `ask` samples them in this order.
    params_domain: Vec<ParamDomain>,
    // Budget attached to every observation produced by `ask`.
    budget: Budget,
}
impl RandomSolver {
    /// Creates a solver over `params_domain` whose observations carry a budget of 1.
    // TODO: delete
    pub fn new(params_domain: Vec<ParamDomain>) -> Self {
        let budget = Budget::new(1);
        RandomSolver {
            params_domain,
            budget,
        }
    }
}
impl Solver for RandomSolver {
    fn specification(&self) -> SolverSpec {
        SolverSpec {
            name: "random".to_owned(),
            version: Some(env!("CARGO_PKG_VERSION").to_owned()),
            capabilities: SolverCapabilities::empty()
                .categorical()
                .discrete()
                .multi_objective(),
        }
    }

    fn ask<R: Rng, G: IdGen>(&mut self, rng: &mut R, idg: &mut G) -> Result<UnobservedObs> {
        let mut params = Vec::new();
        for p in &self.params_domain {
            let v = match p {
                ParamDomain::Categorical(p) => {
                    ParamValue::Categorical(rng.gen_range(0, p.choices.len()))
                }
                ParamDomain::Conditional(_) => {
                    track_panic!(ErrorKind::Incapable);
                }
                ParamDomain::Continuous(p) => {
                    track_assert_eq!(p.distribution, Distribution::Uniform, ErrorKind::Incapable);

                    let n = rng.gen_range(p.range.low.get(), p.range.high.get());
                    ParamValue::Continuous(unsafe { FiniteF64::new_unchecked(n) })
                }
                ParamDomain::Discrete(p) => {
                    ParamValue::Discrete(rng.gen_range(p.range.low, p.range.high))
                }
            };
            params.push(v);
        }
        let obs = track!(Obs::new(idg, Budgeted::new(self.budget, params)))?;
        Ok(obs)
    }

    fn tell(&mut self, _obs: ObservedObs) -> Result<()> {
        Ok(())
    }
}