1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
// https://github.com/automl/nas_benchmarks/blob/master/tabular_benchmarks/fcnet_benchmark.py
use hdf5file::{self, DataObject, Hdf5File};
use kurobako_core::parameter::{choices, int, ParamValue};
use kurobako_core::problem::{
    Evaluate, EvaluatorCapability, Problem, ProblemRecipe, ProblemSpec, Values,
};
use kurobako_core::{Error, ErrorKind, Result};
use rustats::num::FiniteF64;
use rustats::range::MinMax;
use serde::{Deserialize, Serialize};
use std::cell::RefCell;
use std::num::NonZeroU64;
use std::path::PathBuf;
use std::rc::Rc;
use structopt::StructOpt;
use yamakan::budget::Budget;
use yamakan::observation::ObsId;

/// Wraps an `hdf5file` error into this crate's `Error` with kind `ErrorKind::Other`.
///
/// The original error is handed to `takes_over`, so it is carried along by the
/// resulting error rather than discarded.
fn into_error(e: hdf5file::Error) -> Error {
    use trackable::error::ErrorKindExt as _;

    let wrapped = ErrorKind::Other.takes_over(e);
    wrapped.into()
}

// Recipe describing how to build an `FcNetProblem` from an HDF5 dataset file.
//
// The type derives `StructOpt` and serde (de)serialization; both are configured with
// `rename_all = "kebab-case"`.
//
// NOTE(review): plain `//` comments are used on purpose. `///` docs on a `StructOpt`
// type are presumably picked up as CLI help text -- confirm before converting them.
#[derive(Debug, Clone, StructOpt, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
#[structopt(rename_all = "kebab-case")]
pub struct FcNetProblemRecipe {
    // Path of the HDF5 file opened by `create_problem`; its file stem becomes the
    // problem name.
    pub dataset_path: PathBuf,
}
impl ProblemRecipe for FcNetProblemRecipe {
    type Problem = FcNetProblem;

    /// Opens the dataset at `dataset_path` and builds the problem around it.
    ///
    /// The problem is named after the file stem of `dataset_path`.
    ///
    /// # Errors
    ///
    /// Fails if the HDF5 file cannot be opened, or with `ErrorKind::InvalidInput` if the
    /// path has no file stem or the stem cannot be converted to `&str`.
    fn create_problem(&self) -> Result<Self::Problem> {
        let dataset = track!(Hdf5File::open_file(&self.dataset_path).map_err(into_error))?;

        // The asserted expression is stringified into the error message by the macro,
        // so it is kept verbatim.
        let name = track_assert_some!(
            self.dataset_path.file_stem().and_then(|n| n.to_str()),
            ErrorKind::InvalidInput
        )
        .to_owned();

        // The handle is reference-counted so that evaluators can share it.
        let file = Rc::new(RefCell::new(dataset));
        Ok(FcNetProblem { file, name })
    }
}

/// Problem backed by an opened HDF5 dataset file.
///
/// Built by `FcNetProblemRecipe::create_problem`.
#[derive(Debug)]
pub struct FcNetProblem {
    /// Opened dataset; the same handle is cloned into every evaluator created by
    /// `create_evaluator`.
    file: Rc<RefCell<Hdf5File>>,
    /// Problem name reported in the specification (the dataset's file stem).
    name: String,
}
impl Problem for FcNetProblem {
    type Evaluator = FcNetEvaluator;

    fn specification(&self) -> ProblemSpec {
        let params_domain = vec![
            choices("activation_fn_1", &["tanh", "relu"]),
            choices("activation_fn_2", &["tanh", "relu"]),
            int("batch_size", 0, 4).unwrap(),
            int("dropout_1", 0, 3).unwrap(),
            int("dropout_2", 0, 3).unwrap(),
            int("init_lr", 0, 6).unwrap(),
            choices("lr_schedule", &["cosine", "const"]),
            int("n_units_1", 0, 6).unwrap(),
            int("n_units_2", 0, 6).unwrap(),
        ];

        ProblemSpec {
            name: self.name.clone(),
            version: None,
            params_domain,
            values_domain: unsafe {
                vec![MinMax::new_unchecked(
                    FiniteF64::new_unchecked(0.0),
                    FiniteF64::new_unchecked(1.0),
                )]
            },
            evaluation_expense: unsafe { NonZeroU64::new_unchecked(100) },
            capabilities: vec![EvaluatorCapability::Concurrent].into_iter().collect(),
        }
    }

    fn create_evaluator(&mut self, id: ObsId) -> Result<Self::Evaluator> {
        Ok(FcNetEvaluator {
            file: self.file.clone(),
            sample_index: id.get() as usize % 4,
        })
    }
}

/// Evaluator that looks up results in the shared HDF5 dataset.
#[derive(Debug)]
pub struct FcNetEvaluator {
    /// Dataset handle shared with the `FcNetProblem` that created this evaluator.
    file: Rc<RefCell<Hdf5File>>,
    /// First index used when reading a value out of the `valid_mse` array; set to the
    /// observation id modulo 4 by `FcNetProblem::create_evaluator`.
    sample_index: usize,
}
impl Evaluate for FcNetEvaluator {
    fn evaluate(&mut self, params: &[ParamValue], budget: &mut Budget) -> Result<Values> {
        const UNITS: [usize; 6] = [16, 32, 64, 128, 256, 512];
        const DROPOUTS: [&str; 3] = ["0.0", "0.3", "0.6"];

        fn index(p: &ParamValue) -> usize {
            p.as_discrete().unwrap() as usize
        }

        let key = format!(
            r#"{{"activation_fn_1": {:?}, "activation_fn_2": {:?}, "batch_size": {}, "dropout_1": {}, "dropout_2": {}, "init_lr": {}, "lr_schedule": {:?}, "n_units_1": {}, "n_units_2": {}}}"#,
            (["tanh", "relu"])[params[0].as_categorical().unwrap()],
            (["tanh", "relu"])[params[1].as_categorical().unwrap()],
            ([8, 16, 32, 64])[index(&params[2])],
            DROPOUTS[index(&params[3])],
            DROPOUTS[index(&params[4])],
            ([5.0 * 1e-4, 1e-3, 5.0 * 1e-3, 1e-2, 5.0 * 1e-2, 1e-1])[index(&params[5])],
            (["cosine", "const"])[params[6].as_categorical().unwrap()],
            UNITS[index(&params[7])],
            UNITS[index(&params[8])]
        );

        let data = track!(self
            .file
            .borrow_mut()
            .get_object(format!("/{}/valid_mse", key))
            .map_err(into_error))?;
        let DataObject::Float(data) = track_assert_some!(data, ErrorKind::InvalidInput; key);

        let value = data[[self.sample_index, budget.amount as usize - 1]];
        budget.consumption = budget.amount;
        Ok(vec![FiniteF64::new(value)?])
    }
}