1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
use kurobako_core::problem::ProblemSpec;
use kurobako_core::registry::FactoryRegistry;
use kurobako_core::rng::ArcRng;
use kurobako_core::solver::{
Capabilities, Solver, SolverFactory, SolverRecipe, SolverSpec, SolverSpecBuilder,
};
use kurobako_core::trial::{EvaluatedTrial, IdGen, NextTrial, Params};
use kurobako_core::{ErrorKind, Result};
use rand::distributions::Distribution as _;
use serde::{Deserialize, Serialize};
use structopt::StructOpt;
/// Returns `true` when the given flag is unset.
///
/// Used as a serde `skip_serializing_if` predicate so that a `false` flag is left
/// out of serialized output. serde passes the field by reference, which is why this
/// takes `&bool` even though `bool` is `Copy` (hence the clippy allowance).
#[allow(clippy::trivially_copy_pass_by_ref)]
fn is_false(b: &bool) -> bool {
    let flag = *b;
    !flag
}
// Recipe (CLI flags / serialized settings) used to build a `RandomSolverFactory`.
//
// NOTE: plain `//` comments are used on purpose: StructOpt turns `///` doc comments
// into `--help` text, so doc comments here would change the CLI's output.
#[derive(Debug, Clone, Serialize, Deserialize, StructOpt)]
pub struct RandomSolverRecipe {
    // When set, solvers walk through the problem's steps one at a time instead of
    // always requesting the final step. Defaults to `false` when absent from the
    // input and is omitted from serialized output while `false`.
    #[serde(default)]
    #[serde(skip_serializing_if = "is_false")]
    #[structopt(long)]
    ask_all_steps: bool,
}
impl SolverRecipe for RandomSolverRecipe {
    type Factory = RandomSolverFactory;

    /// Builds a `RandomSolverFactory` carrying over this recipe's `ask_all_steps` flag.
    ///
    /// The registry argument is not used.
    fn create_factory(&self, _registry: &FactoryRegistry) -> Result<Self::Factory> {
        let ask_all_steps = self.ask_all_steps;
        let factory = RandomSolverFactory { ask_all_steps };
        Ok(factory)
    }
}
/// Factory that creates `RandomSolver` instances.
#[derive(Debug)]
pub struct RandomSolverFactory {
    // Copied from the recipe; when `true`, created solvers start with step
    // tracking enabled (`current_step = Some(0)`).
    ask_all_steps: bool,
}
impl SolverFactory for RandomSolverFactory {
    type Solver = RandomSolver;

    /// Describes this solver as `"Random"`, with a `version` attribute derived from
    /// `CARGO_PKG_VERSION` and every capability advertised.
    fn specification(&self) -> Result<SolverSpec> {
        let version = format!("kurobako_solvers={}", env!("CARGO_PKG_VERSION"));
        let builder = SolverSpecBuilder::new("Random")
            .attr("version", &version)
            .capabilities(Capabilities::all());
        Ok(builder.finish())
    }

    /// Creates a `RandomSolver` for `problem` that draws its randomness from `rng`.
    fn create_solver(&self, rng: ArcRng, problem: &ProblemSpec) -> Result<Self::Solver> {
        // `Some(0)` enables step tracking: the solver then asks for steps in
        // increasing order. `None` makes every trial request the problem's last step.
        let current_step = match self.ask_all_steps {
            true => Some(0),
            false => None,
        };
        Ok(RandomSolver {
            rng,
            problem: problem.clone(),
            current_step,
        })
    }
}
/// Random-search solver: every parameter of a new trial is sampled from its domain.
#[derive(Debug)]
pub struct RandomSolver {
    // Source of randomness used when sampling parameters.
    rng: ArcRng,
    // Problem being solved; supplies the parameter domain and the evaluation steps.
    problem: ProblemSpec,
    // `Some(step)` when step tracking (`ask_all_steps`) is enabled: the step most
    // recently reported via `tell` (initially 0). `None` means every trial is
    // requested at the problem's last step.
    current_step: Option<u64>,
}
impl Solver for RandomSolver {
    /// Proposes a new trial whose parameters are sampled at random from the
    /// problem's parameter domain.
    ///
    /// Without step tracking (`current_step == None`), the trial is requested at the
    /// problem's last step. With step tracking, the trial is requested at the first
    /// step strictly after the step most recently reported via `tell`. Once the last
    /// step has been reached, the walk restarts from the beginning rather than failing.
    ///
    /// # Errors
    ///
    /// Returns an `ErrorKind::Bug` error if step tracking is enabled but the problem
    /// defines no step greater than zero.
    fn ask(&mut self, idg: &mut IdGen) -> Result<NextTrial> {
        let mut params = Vec::new();
        for p in self.problem.params_domain.variables() {
            params.push(p.sample(&mut self.rng));
        }

        let steps = &self.problem.steps;
        let next_step = match self.current_step {
            None => steps.last(),
            Some(current_step) => {
                // After a trial has been reported at the final step, no later step
                // exists. Rather than aborting the whole run with a `Bug` error, restart
                // from the initial tracking state (step 0, the value `create_solver`
                // starts from) and cycle through the steps again.
                let step = steps
                    .iter()
                    .find(|&s| s > current_step)
                    .or_else(|| steps.iter().find(|&s| s > 0));
                track_assert_some!(step, ErrorKind::Bug)
            }
        };

        Ok(NextTrial {
            id: idg.generate(),
            params: Params::new(params),
            next_step: Some(next_step),
        })
    }

    /// Records the step the evaluated trial reached when step tracking is enabled.
    /// Otherwise, the evaluation result is discarded.
    fn tell(&mut self, trial: EvaluatedTrial) -> Result<()> {
        if let Some(step) = self.current_step.as_mut() {
            *step = trial.current_step;
        }
        Ok(())
    }
}