use biosphere::{MaxFeatures, RandomForestParameters};
/// Collection of tuning parameters.
///
/// Start from `Control::default()` and adjust individual values with the `with_*` builder
/// methods; some of those setters validate their argument and panic on invalid input.
#[derive(Clone)]
pub struct Control {
    /// Minimal segment length relative to the number of observations
    /// (NOTE(review): interpretation inferred from the name — confirm against its use).
    /// `with_minimal_relative_segment_length` rejects values outside the open interval
    /// (0, 0.5). Default: 0.01.
    pub minimal_relative_segment_length: f64,
    /// Optional fixed threshold on the gain required to split a segment. Default: `None`.
    /// NOTE(review): presumably `None` leaves the split decision to model selection — confirm.
    pub minimal_gain_to_split: Option<f64>,
    /// Level `alpha` used in model selection; `with_model_selection_alpha` rejects values
    /// outside the open interval (0, 1). Default: 0.02.
    /// NOTE(review): presumably a significance level — confirm.
    pub model_selection_alpha: f64,
    /// Number of permutations used in model selection. Default: 199.
    pub model_selection_n_permutations: usize,
    /// Number of wild segments (presumably randomly drawn — confirm). Default: 100.
    pub number_of_wild_segments: usize,
    /// Parameter `alpha` of the seeded segments; `with_seeded_segments_alpha` rejects values
    /// outside the open interval (0, 1). Default: 1/sqrt(2).
    pub seeded_segments_alpha: f64,
    /// Seed, presumably for the random number generator(s) in use — confirm. Default: 0.
    pub seed: u64,
    /// Parameters for the `biosphere` random forest. Default: `RandomForestParameters::default()`
    /// with max depth 8, `MaxFeatures::Sqrt` and `n_jobs = Some(-1)`.
    pub random_forest_parameters: RandomForestParameters,
}
impl Control {
    /// Panics unless `value` lies in the open interval `(lower, upper)`.
    ///
    /// Written as a positive range test so that NaN (for which every comparison is `false`)
    /// is rejected instead of silently accepted. The panic message reads
    /// `"<name> needs to be strictly between <lower> and <upper>. Got <value>"`.
    #[track_caller]
    fn assert_strictly_between(name: &str, value: f64, lower: f64, upper: f64) {
        assert!(
            value > lower && value < upper,
            "{} needs to be strictly between {} and {}. Got {}",
            name,
            lower,
            upper,
            value
        );
    }

    /// Returns `self` with `minimal_relative_segment_length` replaced.
    ///
    /// # Panics
    ///
    /// Panics if `minimal_relative_segment_length` is not strictly between 0 and 0.5
    /// (this includes NaN).
    #[must_use]
    pub fn with_minimal_relative_segment_length(
        mut self,
        minimal_relative_segment_length: f64,
    ) -> Self {
        Self::assert_strictly_between(
            "minimal_relative_segment_length",
            minimal_relative_segment_length,
            0.,
            0.5,
        );
        self.minimal_relative_segment_length = minimal_relative_segment_length;
        self
    }

    /// Returns `self` with `minimal_gain_to_split` replaced. No validation is performed.
    #[must_use]
    pub fn with_minimal_gain_to_split(mut self, minimal_gain_to_split: Option<f64>) -> Self {
        self.minimal_gain_to_split = minimal_gain_to_split;
        self
    }

    /// Returns `self` with `model_selection_alpha` replaced.
    ///
    /// # Panics
    ///
    /// Panics if `model_selection_alpha` is not strictly between 0 and 1 (this includes NaN).
    #[must_use]
    pub fn with_model_selection_alpha(mut self, model_selection_alpha: f64) -> Self {
        Self::assert_strictly_between("model_selection_alpha", model_selection_alpha, 0., 1.);
        self.model_selection_alpha = model_selection_alpha;
        self
    }

    /// Returns `self` with `model_selection_n_permutations` replaced. No validation is performed.
    #[must_use]
    pub fn with_model_selection_n_permutations(
        mut self,
        model_selection_n_permutations: usize,
    ) -> Self {
        self.model_selection_n_permutations = model_selection_n_permutations;
        self
    }

    /// Returns `self` with `number_of_wild_segments` replaced. No validation is performed.
    #[must_use]
    pub fn with_number_of_wild_segments(mut self, number_of_wild_segments: usize) -> Self {
        self.number_of_wild_segments = number_of_wild_segments;
        self
    }

    /// Returns `self` with `seeded_segments_alpha` replaced.
    ///
    /// # Panics
    ///
    /// Panics if `seeded_segments_alpha` is not strictly between 0 and 1 (this includes NaN).
    #[must_use]
    pub fn with_seeded_segments_alpha(mut self, seeded_segments_alpha: f64) -> Self {
        Self::assert_strictly_between("seeded_segments_alpha", seeded_segments_alpha, 0., 1.);
        self.seeded_segments_alpha = seeded_segments_alpha;
        self
    }

    /// Returns `self` with `seed` replaced.
    #[must_use]
    pub fn with_seed(mut self, seed: u64) -> Self {
        self.seed = seed;
        self
    }

    /// Returns `self` with `random_forest_parameters` replaced. No validation is performed.
    #[must_use]
    pub fn with_random_forest_parameters(
        mut self,
        random_forest_parameters: RandomForestParameters,
    ) -> Self {
        self.random_forest_parameters = random_forest_parameters;
        self
    }
}

// Implemented as the `Default` trait (rather than an inherent `default()`), so that
// `Control::default()` keeps working via the prelude and `..Default::default()` is available.
impl Default for Control {
    /// Returns the default configuration.
    ///
    /// The random forest starts from `RandomForestParameters::default()` with a maximal depth
    /// of 8, `MaxFeatures::Sqrt` and `n_jobs = Some(-1)` (NOTE(review): presumably "use all
    /// available cores" per biosphere's convention — confirm).
    fn default() -> Self {
        Control {
            minimal_relative_segment_length: 0.01,
            minimal_gain_to_split: None,
            model_selection_alpha: 0.02,
            model_selection_n_permutations: 199,
            number_of_wild_segments: 100,
            // 1/sqrt(2)
            seeded_segments_alpha: std::f64::consts::FRAC_1_SQRT_2,
            seed: 0,
            random_forest_parameters: RandomForestParameters::default()
                .with_max_depth(Some(8))
                .with_max_features(MaxFeatures::Sqrt)
                .with_n_jobs(Some(-1)),
        }
    }
}