1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
//! A solver based on random search.
use kurobako_core::problem::ProblemSpec;
use kurobako_core::registry::FactoryRegistry;
use kurobako_core::rng::ArcRng;
use kurobako_core::solver::{
    Capabilities, Solver, SolverFactory, SolverRecipe, SolverSpec, SolverSpecBuilder,
};
use kurobako_core::trial::{EvaluatedTrial, IdGen, NextTrial, Params};
use kurobako_core::{ErrorKind, Result};
use rand::distributions::Distribution as _;
use serde::{Deserialize, Serialize};
use structopt::StructOpt;

// Takes `&bool` rather than `bool` because this function is referenced by name from
// `#[serde(skip_serializing_if = "is_false")]`, which calls the predicate with a reference
// to the field.
#[allow(clippy::trivially_copy_pass_by_ref)]
fn is_false(b: &bool) -> bool {
    !*b
}

/// Recipe of `RandomSolver`.
// NOTE: this type derives `StructOpt`, which turns the `///` comments on the struct and its
// fields into command-line help text; rewording them changes user-visible output.
#[derive(Debug, Clone, StructOpt, Serialize, Deserialize)]
pub struct RandomSolverRecipe {
    /// If this flag is set, this solver asks evaluators to evaluate parameters at every intermediate step.
    // Exposed as a long command-line flag (`#[structopt(long)]`).
    // For serde: falls back to `bool::default()` (`false`) when the key is absent, and is
    // left out of the serialized form while it is `false` (see `is_false`).
    #[structopt(long)]
    #[serde(default, skip_serializing_if = "is_false")]
    ask_all_steps: bool,
}
impl SolverRecipe for RandomSolverRecipe {
    type Factory = RandomSolverFactory;

    /// Builds a `RandomSolverFactory` that carries over this recipe's `ask_all_steps` flag.
    ///
    /// The registry argument is not consulted, and this implementation always returns `Ok`.
    fn create_factory(&self, _registry: &FactoryRegistry) -> Result<Self::Factory> {
        let ask_all_steps = self.ask_all_steps;
        Ok(RandomSolverFactory { ask_all_steps })
    }
}

/// Factory of `RandomSolver`.
#[derive(Debug)]
pub struct RandomSolverFactory {
    // Copied from `RandomSolverRecipe::ask_all_steps`; decides whether solvers created by
    // this factory start with `current_step` set to `Some(0)` or to `None`.
    ask_all_steps: bool,
}
impl SolverFactory for RandomSolverFactory {
    type Solver = RandomSolver;

    /// Describes this solver: its name is `"Random"`, it carries a `"version"` attribute
    /// built from this crate's package version, and it declares `Capabilities::all()`.
    fn specification(&self) -> Result<SolverSpec> {
        // `CARGO_PKG_VERSION` is read at compile time.
        let version = format!("kurobako_solvers={}", env!("CARGO_PKG_VERSION"));
        let builder = SolverSpecBuilder::new("Random")
            .attr("version", &version)
            .capabilities(Capabilities::all());
        Ok(builder.finish())
    }

    /// Creates a `RandomSolver` that owns `rng` and a clone of `problem`.
    ///
    /// The solver starts with `current_step` at `Some(0)` when `ask_all_steps` is set, and
    /// at `None` otherwise.
    fn create_solver(&self, rng: ArcRng, problem: &ProblemSpec) -> Result<Self::Solver> {
        let current_step = if self.ask_all_steps { Some(0) } else { None };
        Ok(RandomSolver {
            rng,
            problem: problem.clone(),
            current_step,
        })
    }
}

/// Solver based on random search.
#[derive(Debug)]
pub struct RandomSolver {
    // Random number generator handed to each variable's `sample` call in `ask`.
    rng: ArcRng,
    // Clone of the problem specification; `ask` reads its `params_domain` and `steps`.
    problem: ProblemSpec,
    // `Some(step)` when the solver was created with `ask_all_steps` set (initially `Some(0)`),
    // in which case `tell` overwrites it with the told trial's `current_step`.
    // `None` otherwise, in which case `ask` always requests the last step.
    current_step: Option<u64>,
}
impl Solver for RandomSolver {
    /// Samples one value for every variable of the problem's parameter domain and returns
    /// them as a new trial with an ID taken from `idg`.
    ///
    /// The requested `next_step` is:
    /// - the first entry of `problem.steps` that is greater than `current_step`, when
    ///   `current_step` is `Some`;
    /// - the last step of the problem otherwise.
    ///
    /// # Errors
    ///
    /// Returns an `ErrorKind::Bug` error when `current_step` is `Some` and no step of the
    /// problem is greater than it.
    // NOTE(review): that error path is reached once `tell` has recorded a step that is not
    // below any remaining step (e.g. the last one) — confirm callers never `ask` again then.
    fn ask(&mut self, idg: &mut IdGen) -> Result<NextTrial> {
        // Borrow the RNG field on its own so the closure does not need to capture `self`
        // while `self.problem` is being iterated.
        let rng = &mut self.rng;
        let params: Vec<_> = self
            .problem
            .params_domain
            .variables()
            .iter()
            .map(|p| p.sample(&mut *rng))
            .collect();

        let next_step = match self.current_step {
            Some(current_step) => {
                // The binding is deliberately still called `step`: the assertion macro
                // receives the expression itself and may include its text in the error.
                let step = self.problem.steps.iter().find(|&s| s > current_step);
                track_assert_some!(step, ErrorKind::Bug)
            }
            None => self.problem.steps.last(),
        };

        Ok(NextTrial {
            id: idg.generate(),
            params: Params::new(params),
            next_step: Some(next_step),
        })
    }

    /// Records the evaluated trial's `current_step`, but only when step tracking is enabled
    /// (`current_step` is `Some`); otherwise the trial is ignored. Always returns `Ok(())`.
    fn tell(&mut self, trial: EvaluatedTrial) -> Result<()> {
        self.current_step = self.current_step.map(|_| trial.current_step);
        Ok(())
    }
}