use crate::grid::coordinate_system::CoordinateSystem;
use rand::{rngs::StdRng, Rng};
use crate::NodeIndex;
use super::rules::Rules;
/// Strategy used to choose which node to observe next.
///
/// Every strategy only considers nodes that still have more than one possible model.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum NodeSelectionHeuristic {
    /// Picks the node with the fewest possible models.
    ///
    /// A small random noise is added to each candidate's count so that ties are broken randomly.
    MinimumRemainingValue,
    /// Picks the node with the lowest entropy, computed from the weights of its possible models.
    ///
    /// A small random noise is added to each candidate's entropy, which also breaks ties.
    MinimumEntropy,
    /// Picks one of the candidate nodes at random.
    Random,
}
/// Scale of the random noise added to each candidate node's score in `select_node`,
/// so that nodes with equal scores are picked in a random order.
const MAX_NOISE_VALUE: f32 = 0.01;
/// Runtime state backing a [`NodeSelectionHeuristic`].
///
/// Built from the public heuristic by `from_external`; only `MinimumEntropy` carries data.
pub(crate) enum InternalNodeSelectionHeuristic {
    /// See [`NodeSelectionHeuristic::MinimumRemainingValue`].
    MinimumRemainingValue,
    /// See [`NodeSelectionHeuristic::MinimumEntropy`].
    MinimumEntropy {
        /// Entropy data of a node on which every model is still possible; used to reset nodes.
        initial_node_entropy_data: NodeEntropyData,
        /// Current entropy data, one entry per node, indexed by node index.
        node_entropies: Vec<NodeEntropyData>,
        /// `weight * ln(weight)` for each model, indexed by model index.
        models_weight_log_weights: Vec<f32>,
    },
    /// See [`NodeSelectionHeuristic::Random`].
    Random,
}
/// Cached entropy of one node, kept up to date incrementally as models get banned.
///
/// Stores the two running sums the entropy formula needs, so removing a model only
/// requires subtracting that model's terms instead of re-summing every remaining model.
#[derive(Clone, Copy)]
pub(crate) struct NodeEntropyData {
    /// Last value computed by the free `entropy` function from the two sums below.
    entropy: f32,
    /// Σ wᵢ over the models still possible on the node.
    weight_sum: f32,
    /// Σ wᵢ·ln(wᵢ) over the models still possible on the node.
    weight_log_weight_sum: f32,
}

impl NodeEntropyData {
    /// Creates the data from both running sums and eagerly computes the matching entropy.
    fn new(weight_sum: f32, weight_log_weight_sum: f32) -> Self {
        let computed_entropy = entropy(weight_sum, weight_log_weight_sum);
        Self {
            weight_sum,
            weight_log_weight_sum,
            entropy: computed_entropy,
        }
    }

    /// Returns the cached entropy value.
    pub(crate) fn entropy(&self) -> f32 {
        self.entropy
    }
}

/// Shannon entropy of a weighted distribution, computed from its running sums.
///
/// With W = Σ wᵢ and pᵢ = wᵢ / W, `-Σ pᵢ ln pᵢ` simplifies to `ln W - (Σ wᵢ ln wᵢ) / W`.
/// The result is not finite when `weight_sum` is 0, because `ln(0)` is `-inf`.
fn entropy(weight_sum: f32, weight_log_weight_sum: f32) -> f32 {
    let mean_weight_log_weight = weight_log_weight_sum / weight_sum;
    weight_sum.ln() - mean_weight_log_weight
}
impl InternalNodeSelectionHeuristic {
    /// Builds the internal heuristic state matching the public `heuristic` choice.
    ///
    /// Only [`NodeSelectionHeuristic::MinimumEntropy`] needs precomputed data: it reads every
    /// model weight from `rules` and allocates one entropy entry per node (`node_count`).
    pub(crate) fn from_external<T: CoordinateSystem + Clone>(
        heuristic: NodeSelectionHeuristic,
        rules: &Rules<T>,
        node_count: usize,
    ) -> Self {
        match heuristic {
            NodeSelectionHeuristic::MinimumRemainingValue => Self::MinimumRemainingValue,
            NodeSelectionHeuristic::Random => Self::Random,
            NodeSelectionHeuristic::MinimumEntropy => {
                Self::new_minimum_entropy(rules, node_count)
            }
        }
    }

    /// Builds the `MinimumEntropy` state.
    ///
    /// Caches `weight * ln(weight)` for every model, and starts every node from the entropy
    /// of the full model set.
    fn new_minimum_entropy<T: CoordinateSystem + Clone>(
        rules: &Rules<T>,
        node_count: usize,
    ) -> InternalNodeSelectionHeuristic {
        let models_count = rules.models_count();
        let mut models_weight_log_weights = Vec::with_capacity(models_count);
        let mut all_models_weight_sum = 0.;
        let mut all_models_weight_log_weight_sum = 0.;
        for model_index in 0..models_count {
            let weight = rules.weight_unchecked(model_index);
            // w·ln(w) tends to 0 as w → 0⁺, but evaluating it at exactly 0 gives
            // 0 · (-inf) = NaN, which would poison every node's entropy sums.
            let weight_log_weight = if weight == 0. {
                0.
            } else {
                weight * f32::ln(weight)
            };
            models_weight_log_weights.push(weight_log_weight);
            all_models_weight_sum += weight;
            all_models_weight_log_weight_sum += weight_log_weight;
        }
        let initial_node_entropy_data =
            NodeEntropyData::new(all_models_weight_sum, all_models_weight_log_weight_sum);
        InternalNodeSelectionHeuristic::MinimumEntropy {
            initial_node_entropy_data,
            node_entropies: vec![initial_node_entropy_data; node_count],
            models_weight_log_weights,
        }
    }

    /// Resets every node's entropy data to the initial state (all models possible).
    ///
    /// No-op for heuristics that keep no per-node state.
    pub(crate) fn reinitialize(&mut self) {
        if let InternalNodeSelectionHeuristic::MinimumEntropy {
            initial_node_entropy_data,
            node_entropies,
            ..
        } = self
        {
            node_entropies.fill(*initial_node_entropy_data);
        }
    }

    /// Updates the entropy data of `node_index` after `model_index` was banned from it.
    ///
    /// `weight` is subtracted from the node's weight sum as given by the caller. The model's
    /// cached `weight * ln(weight)` is subtracted from the other sum. No-op for heuristics that
    /// keep no per-node state.
    pub(crate) fn handle_ban(&mut self, node_index: NodeIndex, model_index: usize, weight: f32) {
        if let InternalNodeSelectionHeuristic::MinimumEntropy {
            node_entropies,
            models_weight_log_weights,
            ..
        } = self
        {
            let node_entropy = &mut node_entropies[node_index];
            node_entropy.weight_sum -= weight;
            node_entropy.weight_log_weight_sum -= models_weight_log_weights[model_index];
            node_entropy.entropy =
                entropy(node_entropy.weight_sum, node_entropy.weight_log_weight_sum);
        }
    }

    /// Picks the next node to observe among nodes with more than one possible model.
    ///
    /// `possible_models_counts[i]` is read as the number of models still possible on node `i`.
    /// Returns `None` when no node has more than one possible model.
    pub(crate) fn select_node(
        &self,
        possible_models_counts: &[usize],
        rng: &mut StdRng,
    ) -> Option<NodeIndex> {
        match self {
            InternalNodeSelectionHeuristic::MinimumRemainingValue => {
                let mut min = f32::MAX;
                let mut picked_node = None;
                for (index, &possibilities_count) in possible_models_counts.iter().enumerate() {
                    if possibilities_count <= 1 {
                        continue;
                    }
                    // The noise is far smaller than the gap between two distinct counts,
                    // so it only randomizes the order among equal counts.
                    let noisy_count =
                        possibilities_count as f32 + MAX_NOISE_VALUE * rng.gen::<f32>();
                    if noisy_count < min {
                        min = noisy_count;
                        picked_node = Some(index);
                    }
                }
                picked_node
            }
            InternalNodeSelectionHeuristic::MinimumEntropy { node_entropies, .. } => {
                let mut min = f32::MAX;
                let mut picked_node = None;
                for (index, &possibilities_count) in possible_models_counts.iter().enumerate() {
                    if possibilities_count <= 1 {
                        continue;
                    }
                    let node_entropy = node_entropies[index].entropy();
                    // Noise is non-negative, so a node whose raw entropy is already not below
                    // `min` cannot win; skip drawing noise for it.
                    if node_entropy < min {
                        let noisy_entropy = node_entropy + MAX_NOISE_VALUE * rng.gen::<f32>();
                        if noisy_entropy < min {
                            min = noisy_entropy;
                            picked_node = Some(index);
                        }
                    }
                }
                picked_node
            }
            InternalNodeSelectionHeuristic::Random => {
                let candidates: Vec<_> = possible_models_counts
                    .iter()
                    .enumerate()
                    .filter(|&(_, &possibilities_count)| possibilities_count > 1)
                    .map(|(index, _)| index)
                    .collect();
                if candidates.is_empty() {
                    None
                } else {
                    Some(candidates[rng.gen_range(0..candidates.len())])
                }
            }
        }
    }
}