use std::collections::{HashMap, HashSet, VecDeque};
use crate::distribution::Distribution;
use crate::parameter::ParamId;
use crate::sampler::CompletedTrial;
#[derive(Debug, Clone, Default)]
pub struct IntersectionSearchSpace;
impl IntersectionSearchSpace {
#[must_use]
pub fn calculate(trials: &[CompletedTrial]) -> HashMap<ParamId, Distribution> {
if trials.is_empty() {
return HashMap::new();
}
let first_trial = &trials[0];
let mut candidate_params: HashSet<ParamId> =
first_trial.distributions.keys().copied().collect();
for trial in trials.iter().skip(1) {
let trial_params: HashSet<ParamId> = trial.distributions.keys().copied().collect();
candidate_params.retain(|param| trial_params.contains(param));
}
let mut result = HashMap::new();
for param_id in candidate_params {
for trial in trials {
if let Some(dist) = trial.distributions.get(¶m_id) {
result.insert(param_id, dist.clone());
break;
}
}
}
result
}
}
#[derive(Debug, Clone, Default)]
pub struct GroupDecomposedSearchSpace;
impl GroupDecomposedSearchSpace {
#[must_use]
pub fn calculate(trials: &[CompletedTrial]) -> Vec<HashSet<ParamId>> {
if trials.is_empty() {
return Vec::new();
}
let mut all_params: HashSet<ParamId> = HashSet::new();
for trial in trials {
for ¶m_id in trial.distributions.keys() {
all_params.insert(param_id);
}
}
if all_params.is_empty() {
return Vec::new();
}
let mut adjacency: HashMap<ParamId, HashSet<ParamId>> = HashMap::new();
for ¶m in &all_params {
adjacency.insert(param, HashSet::new());
}
for trial in trials {
let trial_params: Vec<ParamId> = trial.distributions.keys().copied().collect();
for (i, ¶m1) in trial_params.iter().enumerate() {
for ¶m2 in trial_params.iter().skip(i + 1) {
adjacency
.get_mut(¶m1)
.expect("param should exist in adjacency map")
.insert(param2);
adjacency
.get_mut(¶m2)
.expect("param should exist in adjacency map")
.insert(param1);
}
}
}
let mut visited: HashSet<ParamId> = HashSet::new();
let mut groups: Vec<HashSet<ParamId>> = Vec::new();
for ¶m in &all_params {
if visited.contains(¶m) {
continue;
}
let mut component: HashSet<ParamId> = HashSet::new();
let mut queue: VecDeque<ParamId> = VecDeque::new();
queue.push_back(param);
visited.insert(param);
while let Some(current) = queue.pop_front() {
component.insert(current);
if let Some(neighbors) = adjacency.get(¤t) {
for &neighbor in neighbors {
if !visited.contains(&neighbor) {
visited.insert(neighbor);
queue.push_back(neighbor);
}
}
}
}
groups.push(component);
}
groups
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::distribution::{CategoricalDistribution, FloatDistribution, IntDistribution};
use crate::param::ParamValue;
use crate::parameter::ParamId;
/// Test helper: builds a `CompletedTrial` from `(id, value, distribution)`
/// triples, with an empty fourth map argument.
fn create_trial(
    id: u64,
    params: Vec<(ParamId, ParamValue, Distribution)>,
    value: f64,
) -> CompletedTrial {
    // Split each triple into one entry for the value map and one for the
    // distribution map, both keyed by the same parameter id.
    let (values, distributions): (HashMap<_, _>, HashMap<_, _>) = params
        .into_iter()
        .map(|(key, sampled, dist)| ((key, sampled), (key, dist)))
        .unzip();
    CompletedTrial::new(id, values, distributions, HashMap::new(), value)
}
#[test]
fn test_empty_trials() {
    // With no trials there is nothing to intersect.
    let no_trials: &[CompletedTrial] = &[];
    assert!(IntersectionSearchSpace::calculate(no_trials).is_empty());
}
#[test]
fn test_single_trial() {
    // A single trial contributes all of its parameters and distributions.
    let float_param = ParamId::new();
    let int_param = ParamId::new();
    let float_dist = Distribution::Float(FloatDistribution {
        low: 0.0,
        high: 1.0,
        log_scale: false,
        step: None,
    });
    let int_dist = Distribution::Int(IntDistribution {
        low: 1,
        high: 10,
        log_scale: false,
        step: None,
    });
    let trials = [create_trial(
        0,
        vec![
            (float_param, ParamValue::Float(0.5), float_dist.clone()),
            (int_param, ParamValue::Int(5), int_dist.clone()),
        ],
        1.0,
    )];

    let space = IntersectionSearchSpace::calculate(&trials);

    assert_eq!(space.len(), 2);
    assert!(space.contains_key(&float_param));
    assert!(space.contains_key(&int_param));
    assert_eq!(space.get(&float_param), Some(&float_dist));
    assert_eq!(space.get(&int_param), Some(&int_dist));
}
#[test]
fn test_all_trials_same_params() {
    // Five trials that all sample the same two parameters.
    let x_id = ParamId::new();
    let y_id = ParamId::new();
    let unit_range = Distribution::Float(FloatDistribution {
        low: 0.0,
        high: 1.0,
        log_scale: false,
        step: None,
    });
    let symmetric_range = Distribution::Float(FloatDistribution {
        low: -1.0,
        high: 1.0,
        log_scale: false,
        step: None,
    });

    let mut trials: Vec<CompletedTrial> = Vec::new();
    for index in 0..5 {
        #[allow(clippy::cast_precision_loss)]
        let point = index as f64 * 0.1;
        trials.push(create_trial(
            index,
            vec![
                (x_id, ParamValue::Float(point), unit_range.clone()),
                (y_id, ParamValue::Float(point - 0.5), symmetric_range.clone()),
            ],
            point * point,
        ));
    }

    let space = IntersectionSearchSpace::calculate(&trials);

    assert_eq!(space.len(), 2);
    assert!(space.contains_key(&x_id));
    assert!(space.contains_key(&y_id));
}
#[test]
fn test_partial_overlap() {
    // x is sampled by every trial; y and z are each missing from one trial.
    let unit = || {
        Distribution::Float(FloatDistribution {
            low: 0.0,
            high: 1.0,
            log_scale: false,
            step: None,
        })
    };
    let x_id = ParamId::new();
    let y_id = ParamId::new();
    let z_id = ParamId::new();

    let trials = [
        create_trial(
            0,
            vec![
                (x_id, ParamValue::Float(0.5), unit()),
                (y_id, ParamValue::Float(0.3), unit()),
            ],
            1.0,
        ),
        create_trial(
            1,
            vec![
                (x_id, ParamValue::Float(0.7), unit()),
                (z_id, ParamValue::Float(0.2), unit()),
            ],
            0.5,
        ),
        create_trial(
            2,
            vec![
                (x_id, ParamValue::Float(0.6), unit()),
                (y_id, ParamValue::Float(0.4), unit()),
                (z_id, ParamValue::Float(0.1), unit()),
            ],
            0.8,
        ),
    ];

    let space = IntersectionSearchSpace::calculate(&trials);

    assert_eq!(space.len(), 1);
    assert!(space.contains_key(&x_id));
    assert!(!space.contains_key(&y_id));
    assert!(!space.contains_key(&z_id));
}
#[test]
fn test_no_common_params() {
    // Two trials with disjoint parameters share nothing.
    let unit = || {
        Distribution::Float(FloatDistribution {
            low: 0.0,
            high: 1.0,
            log_scale: false,
            step: None,
        })
    };
    let x_id = ParamId::new();
    let y_id = ParamId::new();

    let only_x = create_trial(0, vec![(x_id, ParamValue::Float(0.5), unit())], 1.0);
    let only_y = create_trial(1, vec![(y_id, ParamValue::Float(0.3), unit())], 0.5);

    let space = IntersectionSearchSpace::calculate(&[only_x, only_y]);

    assert!(space.is_empty());
}
#[test]
fn test_mixed_distribution_types() {
    // Float, int and categorical parameters all survive the intersection
    // and keep their distribution variant.
    let lr_id = ParamId::new();
    let n_layers_id = ParamId::new();
    let optimizer_id = ParamId::new();

    let float_dist = Distribution::Float(FloatDistribution {
        low: 0.0,
        high: 1.0,
        log_scale: false,
        step: None,
    });
    let int_dist = Distribution::Int(IntDistribution {
        low: 1,
        high: 100,
        log_scale: false,
        step: None,
    });
    let cat_dist = Distribution::Categorical(CategoricalDistribution { n_choices: 3 });

    let trials = [
        create_trial(
            0,
            vec![
                (lr_id, ParamValue::Float(0.01), float_dist.clone()),
                (n_layers_id, ParamValue::Int(3), int_dist.clone()),
                (optimizer_id, ParamValue::Categorical(0), cat_dist.clone()),
            ],
            1.0,
        ),
        create_trial(
            1,
            vec![
                (lr_id, ParamValue::Float(0.001), float_dist.clone()),
                (n_layers_id, ParamValue::Int(5), int_dist.clone()),
                (optimizer_id, ParamValue::Categorical(1), cat_dist.clone()),
            ],
            0.8,
        ),
    ];

    let space = IntersectionSearchSpace::calculate(&trials);

    assert_eq!(space.len(), 3);
    let lr_entry = space.get(&lr_id);
    assert!(matches!(lr_entry, Some(Distribution::Float(_))));
    let layers_entry = space.get(&n_layers_id);
    assert!(matches!(layers_entry, Some(Distribution::Int(_))));
    let optimizer_entry = space.get(&optimizer_id);
    assert!(matches!(
        optimizer_entry,
        Some(Distribution::Categorical(_))
    ));
}
#[test]
fn test_distribution_from_first_trial() {
    // The same parameter is recorded with two different ranges; the result
    // must carry the range recorded by the first trial.
    let x_id = ParamId::new();
    let float_range = |high: f64| {
        Distribution::Float(FloatDistribution {
            low: 0.0,
            high,
            log_scale: false,
            step: None,
        })
    };
    let narrow = float_range(1.0);
    let wide = float_range(10.0);

    let trials = [
        create_trial(0, vec![(x_id, ParamValue::Float(0.5), narrow.clone())], 1.0),
        create_trial(1, vec![(x_id, ParamValue::Float(5.0), wide.clone())], 0.5),
    ];

    let space = IntersectionSearchSpace::calculate(&trials);

    assert_eq!(space.len(), 1);
    assert_eq!(space.get(&x_id), Some(&narrow));
}
#[test]
fn test_many_trials_with_conditional_params() {
    // The dropout rate is only sampled in the first and third trials, so it
    // must be excluded from the intersection.
    let lr_id = ParamId::new();
    let use_dropout_id = ParamId::new();
    let dropout_rate_id = ParamId::new();

    let lr_dist = Distribution::Float(FloatDistribution {
        low: 1e-5,
        high: 1e-1,
        log_scale: true,
        step: None,
    });
    let flag_dist = Distribution::Categorical(CategoricalDistribution { n_choices: 2 });
    let rate_dist = Distribution::Float(FloatDistribution {
        low: 0.0,
        high: 0.5,
        log_scale: false,
        step: None,
    });

    let with_dropout_a = create_trial(
        0,
        vec![
            (lr_id, ParamValue::Float(0.01), lr_dist.clone()),
            (use_dropout_id, ParamValue::Categorical(1), flag_dist.clone()),
            (dropout_rate_id, ParamValue::Float(0.2), rate_dist.clone()),
        ],
        1.0,
    );
    let without_dropout = create_trial(
        1,
        vec![
            (lr_id, ParamValue::Float(0.001), lr_dist.clone()),
            (use_dropout_id, ParamValue::Categorical(0), flag_dist.clone()),
        ],
        0.8,
    );
    let with_dropout_b = create_trial(
        2,
        vec![
            (lr_id, ParamValue::Float(0.005), lr_dist.clone()),
            (use_dropout_id, ParamValue::Categorical(1), flag_dist.clone()),
            (dropout_rate_id, ParamValue::Float(0.3), rate_dist.clone()),
        ],
        0.9,
    );

    let space =
        IntersectionSearchSpace::calculate(&[with_dropout_a, without_dropout, with_dropout_b]);

    assert_eq!(space.len(), 2);
    assert!(space.contains_key(&lr_id));
    assert!(space.contains_key(&use_dropout_id));
    assert!(!space.contains_key(&dropout_rate_id));
}
#[test]
fn test_group_empty_trials() {
    // No trials means no parameters and therefore no groups.
    let no_trials: &[CompletedTrial] = &[];
    assert!(GroupDecomposedSearchSpace::calculate(no_trials).is_empty());
}
#[test]
fn test_group_single_trial_single_param() {
    // One trial with one parameter produces one singleton group.
    let unit = Distribution::Float(FloatDistribution {
        low: 0.0,
        high: 1.0,
        log_scale: false,
        step: None,
    });
    let x_id = ParamId::new();
    let trials = [create_trial(0, vec![(x_id, ParamValue::Float(0.5), unit)], 1.0)];

    let components = GroupDecomposedSearchSpace::calculate(&trials);

    assert_eq!(components.len(), 1);
    let only = &components[0];
    assert!(only.contains(&x_id));
    assert_eq!(only.len(), 1);
}
#[test]
fn test_group_single_trial_multiple_params() {
    // Parameters sampled together in one trial land in the same group.
    let unit = || {
        Distribution::Float(FloatDistribution {
            low: 0.0,
            high: 1.0,
            log_scale: false,
            step: None,
        })
    };
    let x_id = ParamId::new();
    let y_id = ParamId::new();
    let z_id = ParamId::new();
    let trials = [create_trial(
        0,
        vec![
            (x_id, ParamValue::Float(0.5), unit()),
            (y_id, ParamValue::Float(0.3), unit()),
            (z_id, ParamValue::Float(0.7), unit()),
        ],
        1.0,
    )];

    let components = GroupDecomposedSearchSpace::calculate(&trials);

    assert_eq!(components.len(), 1);
    let only = &components[0];
    assert_eq!(only.len(), 3);
    assert!(only.contains(&x_id));
    assert!(only.contains(&y_id));
    assert!(only.contains(&z_id));
}
#[test]
fn test_group_two_independent_groups() {
    // {x, y} and {a, b} never share a trial, so they form separate groups.
    let unit = || {
        Distribution::Float(FloatDistribution {
            low: 0.0,
            high: 1.0,
            log_scale: false,
            step: None,
        })
    };
    let x_id = ParamId::new();
    let y_id = ParamId::new();
    let a_id = ParamId::new();
    let b_id = ParamId::new();

    let trials = [
        create_trial(
            0,
            vec![
                (x_id, ParamValue::Float(0.1), unit()),
                (y_id, ParamValue::Float(0.2), unit()),
            ],
            1.0,
        ),
        create_trial(
            1,
            vec![
                (a_id, ParamValue::Float(0.3), unit()),
                (b_id, ParamValue::Float(0.4), unit()),
            ],
            0.5,
        ),
    ];

    let components = GroupDecomposedSearchSpace::calculate(&trials);
    assert_eq!(components.len(), 2);

    // Group order is not fixed, so look each group up by a known member.
    let containing = |id: ParamId| components.iter().find(|members| members.contains(&id));
    let group_xy = containing(x_id);
    let group_ab = containing(a_id);
    assert!(group_xy.is_some());
    assert!(group_ab.is_some());

    let group_xy = group_xy.expect("group with x should exist");
    let group_ab = group_ab.expect("group with a should exist");

    assert_eq!(group_xy.len(), 2);
    assert!(group_xy.contains(&x_id));
    assert!(group_xy.contains(&y_id));

    assert_eq!(group_ab.len(), 2);
    assert!(group_ab.contains(&a_id));
    assert!(group_ab.contains(&b_id));
}
#[test]
fn test_group_transitive_connection() {
    // x-y share a trial and y-z share a trial, so x, y, z form one group.
    let unit = || {
        Distribution::Float(FloatDistribution {
            low: 0.0,
            high: 1.0,
            log_scale: false,
            step: None,
        })
    };
    let x_id = ParamId::new();
    let y_id = ParamId::new();
    let z_id = ParamId::new();

    let trials = [
        create_trial(
            0,
            vec![
                (x_id, ParamValue::Float(0.1), unit()),
                (y_id, ParamValue::Float(0.2), unit()),
            ],
            1.0,
        ),
        create_trial(
            1,
            vec![
                (y_id, ParamValue::Float(0.3), unit()),
                (z_id, ParamValue::Float(0.4), unit()),
            ],
            0.5,
        ),
    ];

    let components = GroupDecomposedSearchSpace::calculate(&trials);

    assert_eq!(components.len(), 1);
    let merged = &components[0];
    assert_eq!(merged.len(), 3);
    assert!(merged.contains(&x_id));
    assert!(merged.contains(&y_id));
    assert!(merged.contains(&z_id));
}
#[test]
fn test_group_chain_connection() {
    // a-b, b-c and c-d are linked pairwise across three trials, which
    // chains all four parameters into a single group.
    let unit = || {
        Distribution::Float(FloatDistribution {
            low: 0.0,
            high: 1.0,
            log_scale: false,
            step: None,
        })
    };
    let a_id = ParamId::new();
    let b_id = ParamId::new();
    let c_id = ParamId::new();
    let d_id = ParamId::new();

    let trials = [
        create_trial(
            0,
            vec![
                (a_id, ParamValue::Float(0.1), unit()),
                (b_id, ParamValue::Float(0.2), unit()),
            ],
            1.0,
        ),
        create_trial(
            1,
            vec![
                (b_id, ParamValue::Float(0.3), unit()),
                (c_id, ParamValue::Float(0.4), unit()),
            ],
            0.5,
        ),
        create_trial(
            2,
            vec![
                (c_id, ParamValue::Float(0.5), unit()),
                (d_id, ParamValue::Float(0.6), unit()),
            ],
            0.3,
        ),
    ];

    let components = GroupDecomposedSearchSpace::calculate(&trials);

    assert_eq!(components.len(), 1);
    let chain = &components[0];
    assert_eq!(chain.len(), 4);
    assert!(chain.contains(&a_id));
    assert!(chain.contains(&b_id));
    assert!(chain.contains(&c_id));
    assert!(chain.contains(&d_id));
}
#[test]
fn test_group_multiple_isolated_params() {
    // Each trial samples a different lone parameter: three singleton groups.
    let unit = || {
        Distribution::Float(FloatDistribution {
            low: 0.0,
            high: 1.0,
            log_scale: false,
            step: None,
        })
    };
    let x_id = ParamId::new();
    let y_id = ParamId::new();
    let z_id = ParamId::new();

    let trials = [
        create_trial(0, vec![(x_id, ParamValue::Float(0.1), unit())], 1.0),
        create_trial(1, vec![(y_id, ParamValue::Float(0.2), unit())], 0.5),
        create_trial(2, vec![(z_id, ParamValue::Float(0.3), unit())], 0.3),
    ];

    let components = GroupDecomposedSearchSpace::calculate(&trials);
    assert_eq!(components.len(), 3);

    // Check every group is a singleton while gathering the union of members.
    let mut seen: HashSet<ParamId> = HashSet::new();
    for members in &components {
        assert_eq!(members.len(), 1);
        seen.extend(members.iter().copied());
    }
    assert!(seen.contains(&x_id));
    assert!(seen.contains(&y_id));
    assert!(seen.contains(&z_id));
}
#[test]
fn test_group_complex_scenario() {
    // Expected partition: {x, y, z} (linked through y), {a, b}, and {w}.
    let unit = || {
        Distribution::Float(FloatDistribution {
            low: 0.0,
            high: 1.0,
            log_scale: false,
            step: None,
        })
    };
    let x_id = ParamId::new();
    let y_id = ParamId::new();
    let z_id = ParamId::new();
    let a_id = ParamId::new();
    let b_id = ParamId::new();
    let w_id = ParamId::new();

    let trials = [
        create_trial(
            0,
            vec![
                (x_id, ParamValue::Float(0.1), unit()),
                (y_id, ParamValue::Float(0.2), unit()),
            ],
            1.0,
        ),
        create_trial(
            1,
            vec![
                (y_id, ParamValue::Float(0.3), unit()),
                (z_id, ParamValue::Float(0.4), unit()),
            ],
            0.5,
        ),
        create_trial(
            2,
            vec![
                (a_id, ParamValue::Float(0.5), unit()),
                (b_id, ParamValue::Float(0.6), unit()),
            ],
            0.3,
        ),
        create_trial(3, vec![(w_id, ParamValue::Float(0.7), unit())], 0.2),
    ];

    let components = GroupDecomposedSearchSpace::calculate(&trials);
    assert_eq!(components.len(), 3);

    // Group order is not fixed, so look each group up by a known member.
    let containing = |id: ParamId| components.iter().find(|members| members.contains(&id));
    let group_xyz = containing(x_id);
    let group_ab = containing(a_id);
    let group_w = containing(w_id);
    assert!(group_xyz.is_some());
    assert!(group_ab.is_some());
    assert!(group_w.is_some());

    let group_xyz = group_xyz.expect("group with x should exist");
    assert_eq!(group_xyz.len(), 3);
    assert!(group_xyz.contains(&x_id));
    assert!(group_xyz.contains(&y_id));
    assert!(group_xyz.contains(&z_id));

    let group_ab = group_ab.expect("group with a should exist");
    assert_eq!(group_ab.len(), 2);
    assert!(group_ab.contains(&a_id));
    assert!(group_ab.contains(&b_id));

    let group_w = group_w.expect("group with w should exist");
    assert_eq!(group_w.len(), 1);
    assert!(group_w.contains(&w_id));
}
#[test]
fn test_group_all_params_same_trial() {
    // Four parameters sampled by one trial collapse into one group of four.
    let unit = || {
        Distribution::Float(FloatDistribution {
            low: 0.0,
            high: 1.0,
            log_scale: false,
            step: None,
        })
    };
    let a_id = ParamId::new();
    let b_id = ParamId::new();
    let c_id = ParamId::new();
    let d_id = ParamId::new();
    let trials = [create_trial(
        0,
        vec![
            (a_id, ParamValue::Float(0.1), unit()),
            (b_id, ParamValue::Float(0.2), unit()),
            (c_id, ParamValue::Float(0.3), unit()),
            (d_id, ParamValue::Float(0.4), unit()),
        ],
        1.0,
    )];

    let components = GroupDecomposedSearchSpace::calculate(&trials);

    assert_eq!(components.len(), 1);
    assert_eq!(components[0].len(), 4);
}
#[test]
fn test_group_with_mixed_distribution_types() {
    // Grouping depends only on which parameters co-occur, not on their
    // distribution variant: {lr, n_layers} and {optimizer}.
    let lr_id = ParamId::new();
    let n_layers_id = ParamId::new();
    let optimizer_id = ParamId::new();

    let float_dist = Distribution::Float(FloatDistribution {
        low: 0.0,
        high: 1.0,
        log_scale: false,
        step: None,
    });
    let int_dist = Distribution::Int(IntDistribution {
        low: 1,
        high: 10,
        log_scale: false,
        step: None,
    });
    let cat_dist = Distribution::Categorical(CategoricalDistribution { n_choices: 3 });

    let trials = [
        create_trial(
            0,
            vec![
                (lr_id, ParamValue::Float(0.01), float_dist.clone()),
                (n_layers_id, ParamValue::Int(3), int_dist.clone()),
            ],
            1.0,
        ),
        create_trial(
            1,
            vec![(optimizer_id, ParamValue::Categorical(1), cat_dist)],
            0.5,
        ),
        create_trial(
            2,
            vec![
                (lr_id, ParamValue::Float(0.001), float_dist),
                (n_layers_id, ParamValue::Int(5), int_dist),
            ],
            0.8,
        ),
    ];

    let components = GroupDecomposedSearchSpace::calculate(&trials);
    assert_eq!(components.len(), 2);

    // Group order is not fixed, so look each group up by a known member.
    let containing = |id: ParamId| components.iter().find(|members| members.contains(&id));
    let group_lr = containing(lr_id);
    let group_opt = containing(optimizer_id);
    assert!(group_lr.is_some());
    assert!(group_opt.is_some());

    let group_lr = group_lr.expect("group with learning_rate should exist");
    assert_eq!(group_lr.len(), 2);
    assert!(group_lr.contains(&lr_id));
    assert!(group_lr.contains(&n_layers_id));

    let group_opt = group_opt.expect("group with optimizer should exist");
    assert_eq!(group_opt.len(), 1);
    assert!(group_opt.contains(&optimizer_id));
}
#[test]
fn test_group_star_topology() {
    // `center` shares one trial with each of a, b and c, so all four
    // parameters end up in a single group even though a, b and c never
    // appear together.
    let dist = Distribution::Float(FloatDistribution {
        low: 0.0,
        high: 1.0,
        log_scale: false,
        step: None,
    });
    let center_id = ParamId::new();
    let a_id = ParamId::new();
    let b_id = ParamId::new();
    let c_id = ParamId::new();

    let trial1 = create_trial(
        0,
        vec![
            (center_id, ParamValue::Float(0.1), dist.clone()),
            (a_id, ParamValue::Float(0.2), dist.clone()),
        ],
        1.0,
    );
    let trial2 = create_trial(
        1,
        vec![
            (center_id, ParamValue::Float(0.3), dist.clone()),
            (b_id, ParamValue::Float(0.4), dist.clone()),
        ],
        0.5,
    );
    let trial3 = create_trial(
        2,
        vec![
            (center_id, ParamValue::Float(0.5), dist.clone()),
            (c_id, ParamValue::Float(0.6), dist),
        ],
        0.3,
    );

    let groups = GroupDecomposedSearchSpace::calculate(&[trial1, trial2, trial3]);

    assert_eq!(groups.len(), 1);
    assert_eq!(groups[0].len(), 4);
    assert!(groups[0].contains(&center_id));
    assert!(groups[0].contains(&a_id));
    assert!(groups[0].contains(&b_id));
    assert!(groups[0].contains(&c_id));
}
}