use std::collections::HashMap;
use std::hash::Hash;
use std::marker::PhantomData;
use solverforge_core::score::Score;
use solverforge_core::{ConstraintRef, ImpactType};
use crate::api::constraint_set::IncrementalConstraint;
use crate::stream::filter::UniFilter;
/// Incremental constraint that scores how evenly entities are spread across groups.
///
/// Each entity that passes `filter` is mapped by `key_fn` to an optional group key.
/// Entities mapped to `None` are ignored. The constraint keeps per-key member counts
/// together with running moments (number of non-empty groups, Σcount, Σcount²). From
/// these, the population standard deviation of the group sizes is recomputed without
/// rescanning the counts.
///
/// The resulting score is `base_score` scaled by that standard deviation. It is negated
/// for [`ImpactType::Penalty`] and kept positive for [`ImpactType::Reward`].
pub struct BalanceConstraint<S, A, K, E, F, KF, Sc>
where
    Sc: Score,
{
    /// Identity of the constraint; its `name` is what `name()` reports.
    constraint_ref: ConstraintRef,
    /// Whether the scaled standard deviation is subtracted (penalty) or added (reward).
    impact_type: ImpactType,
    /// Pulls the entity collection out of a solution.
    extractor: E,
    /// Entities rejected by this filter never join any group.
    filter: F,
    /// Maps an entity to its group key; `None` means "not grouped".
    key_fn: KF,
    /// Weight that is multiplied by the standard deviation.
    base_score: Sc,
    /// Reported verbatim by `is_hard()`.
    is_hard: bool,

    // ---- incremental state -------------------------------------------------
    /// Current member count of every non-empty group.
    counts: HashMap<K, i64>,
    /// Group key recorded for each entity index at insert time; consulted on retract.
    entity_keys: HashMap<usize, K>,
    /// Number of non-empty groups (kept equal to `counts.len()`).
    group_count: i64,
    /// Sum of all group counts.
    total_count: i64,
    /// Sum of the squared group counts.
    sum_squared: i64,
    /// Binds `S` and `A` to the type without owning values of them. The `fn() -> T`
    /// form keeps the marker `Send`/`Sync` regardless of `S`/`A`.
    _phantom: PhantomData<(fn() -> S, fn() -> A)>,
}
impl<S, A, K, E, F, KF, Sc> BalanceConstraint<S, A, K, E, F, KF, Sc>
where
    S: Send + Sync + 'static,
    A: Clone + Send + Sync + 'static,
    K: Clone + Eq + Hash + Send + Sync + 'static,
    E: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
    F: UniFilter<S, A>,
    KF: Fn(&A) -> Option<K> + Send + Sync,
    Sc: Score + 'static,
{
    /// Creates a balance constraint with empty incremental state.
    ///
    /// * `constraint_ref` – identity of the constraint (its `name` is reported by `name()`).
    /// * `impact_type` – `Penalty` negates the scaled standard deviation, `Reward` keeps it.
    /// * `extractor` – yields the entity collection of a solution.
    /// * `filter` – entities failing this test are ignored.
    /// * `key_fn` – maps an entity to its group key; `None` leaves it ungrouped.
    /// * `base_score` – weight multiplied by the standard deviation of the group sizes.
    /// * `is_hard` – value returned by `is_hard()`.
    pub fn new(
        constraint_ref: ConstraintRef,
        impact_type: ImpactType,
        extractor: E,
        filter: F,
        key_fn: KF,
        base_score: Sc,
        is_hard: bool,
    ) -> Self {
        Self {
            constraint_ref,
            impact_type,
            extractor,
            filter,
            key_fn,
            base_score,
            is_hard,
            counts: HashMap::new(),
            entity_keys: HashMap::new(),
            group_count: 0,
            total_count: 0,
            sum_squared: 0,
            _phantom: PhantomData,
        }
    }

    /// Population standard deviation of `group_count` values, given their sum and the
    /// sum of their squares.
    ///
    /// Uses the identity `n²·Var = n·Σx² − (Σx)²`. With integer counts the right-hand side
    /// is an exact integer, so it is evaluated in `i128` (no overflow for `i64` inputs).
    /// This avoids the floating-point cancellation of `Σx²/n − (Σx/n)²`. As a result,
    /// perfectly balanced groups yield exactly `0.0` and the variance is never negative.
    /// Returns `0.0` when there are no groups.
    fn std_dev_from_moments(group_count: i64, total: i64, sum_squared: i64) -> f64 {
        if group_count <= 0 {
            return 0.0;
        }
        let n = i128::from(group_count);
        let total = i128::from(total);
        let numerator = n * i128::from(sum_squared) - total * total;
        if numerator <= 0 {
            return 0.0;
        }
        // sqrt(numerator / n²) == sqrt(numerator) / n
        (numerator as f64).sqrt() / group_count as f64
    }

    /// Standard deviation of the group sizes from the incrementally maintained moments.
    fn compute_std_dev(&self) -> f64 {
        Self::std_dev_from_moments(self.group_count, self.total_count, self.sum_squared)
    }

    /// Score for the current incremental state: `base_score * std_dev`, negated for
    /// `ImpactType::Penalty`.
    fn compute_score(&self) -> Sc {
        let std_dev = self.compute_std_dev();
        let base = self.base_score.multiply(std_dev);
        match self.impact_type {
            ImpactType::Penalty => -base,
            ImpactType::Reward => base,
        }
    }

    /// Standard deviation of the group sizes in `counts`. It shares its formula with
    /// `compute_std_dev`, so full and incremental evaluation agree.
    fn compute_std_dev_from_counts(counts: &HashMap<K, i64>) -> f64 {
        let (total, sum_squared) = counts
            .values()
            .fold((0i64, 0i64), |(total, sum_sq), &c| (total + c, sum_sq + c * c));
        // A map can never hold anywhere near i64::MAX entries, so this cast is lossless.
        Self::std_dev_from_moments(counts.len() as i64, total, sum_squared)
    }
}
/// Private bookkeeping helpers shared by the `IncrementalConstraint` methods below.
impl<S, A, K, E, F, KF, Sc> BalanceConstraint<S, A, K, E, F, KF, Sc>
where
    S: Send + Sync + 'static,
    A: Clone + Send + Sync + 'static,
    K: Clone + Eq + Hash + Send + Sync + 'static,
    E: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
    F: UniFilter<S, A>,
    KF: Fn(&A) -> Option<K> + Send + Sync,
    Sc: Score + 'static,
{
    /// Counts, from scratch, how many filtered entities of `solution` map to each key.
    /// Entities whose key is `None` are skipped.
    fn group_counts(&self, solution: &S) -> HashMap<K, i64> {
        let mut counts: HashMap<K, i64> = HashMap::new();
        for entity in self.extractor.extract(solution) {
            if !self.filter.test(solution, entity) {
                continue;
            }
            if let Some(key) = (self.key_fn)(entity) {
                *counts.entry(key).or_insert(0) += 1;
            }
        }
        counts
    }

    /// Adds entity `entity_index` to group `key` and updates the running moments.
    fn add_entity_key(&mut self, entity_index: usize, key: K) {
        let slot = self.counts.entry(key.clone()).or_insert(0);
        let old_count = *slot;
        *slot += 1;
        let new_count = *slot;
        if old_count == 0 {
            self.group_count += 1;
        }
        self.total_count += 1;
        self.sum_squared += new_count * new_count - old_count * old_count;
        self.entity_keys.insert(entity_index, key);
    }

    /// Removes one member from group `key` and updates the running moments. An emptied
    /// group is dropped. Returns `false`, changing nothing, if `key` has no members.
    fn remove_group_member(&mut self, key: &K) -> bool {
        let Some(slot) = self.counts.get_mut(key) else {
            return false;
        };
        let old_count = *slot;
        if old_count == 0 {
            return false;
        }
        *slot -= 1;
        let new_count = *slot;
        if new_count == 0 {
            self.counts.remove(key);
            self.group_count -= 1;
        }
        self.total_count -= 1;
        self.sum_squared += new_count * new_count - old_count * old_count;
        true
    }
}

impl<S, A, K, E, F, KF, Sc> IncrementalConstraint<S, Sc>
    for BalanceConstraint<S, A, K, E, F, KF, Sc>
where
    S: Send + Sync + 'static,
    A: Clone + Send + Sync + 'static,
    K: Clone + Eq + Hash + Send + Sync + 'static,
    E: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
    F: UniFilter<S, A>,
    KF: Fn(&A) -> Option<K> + Send + Sync,
    Sc: Score + 'static,
{
    /// Full (non-incremental) evaluation: regroups every entity of `solution` and
    /// returns `base_score * std_dev`, negated for penalties. Returns zero when no
    /// entity is grouped. The incremental state is not touched.
    fn evaluate(&self, solution: &S) -> Sc {
        let counts = self.group_counts(solution);
        if counts.is_empty() {
            return Sc::zero();
        }
        let std_dev = Self::compute_std_dev_from_counts(&counts);
        let base = self.base_score.multiply(std_dev);
        match self.impact_type {
            ImpactType::Penalty => -base,
            ImpactType::Reward => base,
        }
    }

    /// Number of groups whose size differs from the mean group size by more than 0.5.
    fn match_count(&self, solution: &S) -> usize {
        let counts = self.group_counts(solution);
        if counts.is_empty() {
            return 0;
        }
        let total: i64 = counts.values().sum();
        let mean = total as f64 / counts.len() as f64;
        counts
            .values()
            .filter(|&&c| (c as f64 - mean).abs() > 0.5)
            .count()
    }

    /// Discards any previous state, registers every filtered, keyed entity of
    /// `solution`, and returns the resulting score.
    fn initialize(&mut self, solution: &S) -> Sc {
        self.reset();
        let entities = self.extractor.extract(solution);
        // Keys are collected first so the mutation below does not overlap the borrow
        // of the extracted entities.
        let keyed: Vec<(usize, K)> = entities
            .iter()
            .enumerate()
            .filter_map(|(idx, entity)| {
                if !self.filter.test(solution, entity) {
                    return None;
                }
                (self.key_fn)(entity).map(|key| (idx, key))
            })
            .collect();
        for (idx, key) in keyed {
            self.add_entity_key(idx, key);
        }
        self.compute_score()
    }

    /// Registers entity `entity_index` and returns the score delta.
    ///
    /// If the index is still tracked from an earlier insert without a matching retract,
    /// its stale membership is removed first, so the entity is never counted twice.
    fn on_insert(&mut self, solution: &S, entity_index: usize, _descriptor_index: usize) -> Sc {
        let entities = self.extractor.extract(solution);
        if entity_index >= entities.len() {
            return Sc::zero();
        }
        let entity = &entities[entity_index];
        let key = if self.filter.test(solution, entity) {
            (self.key_fn)(entity)
        } else {
            None
        };
        let stale = self.entity_keys.remove(&entity_index);
        if stale.is_none() && key.is_none() {
            return Sc::zero();
        }
        let old_score = self.compute_score();
        if let Some(stale) = stale {
            self.remove_group_member(&stale);
        }
        if let Some(key) = key {
            self.add_entity_key(entity_index, key);
        }
        self.compute_score() - old_score
    }

    /// Unregisters entity `entity_index` using the key cached at insert time and
    /// returns the score delta (zero if the entity was not tracked).
    ///
    /// Only the cached key is needed. The entity need not still be present in the
    /// solution, so a retract after the collection shrank still cleans up its group.
    fn on_retract(&mut self, _solution: &S, entity_index: usize, _descriptor_index: usize) -> Sc {
        let Some(key) = self.entity_keys.remove(&entity_index) else {
            return Sc::zero();
        };
        let old_score = self.compute_score();
        if !self.remove_group_member(&key) {
            return Sc::zero();
        }
        self.compute_score() - old_score
    }

    /// Clears all incremental state.
    fn reset(&mut self) {
        self.counts.clear();
        self.entity_keys.clear();
        self.group_count = 0;
        self.total_count = 0;
        self.sum_squared = 0;
    }

    fn name(&self) -> &str {
        &self.constraint_ref.name
    }

    fn is_hard(&self) -> bool {
        self.is_hard
    }

    fn constraint_ref(&self) -> ConstraintRef {
        self.constraint_ref.clone()
    }
}
impl<S, A, K, E, F, KF, Sc> std::fmt::Debug for BalanceConstraint<S, A, K, E, F, KF, Sc>
where
    Sc: Score,
{
    /// Prints the constraint name, its impact type and the number of groups currently
    /// held in the incremental state. Closures, extractor and filter are not shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let tracked_groups = self.counts.len();
        let mut out = f.debug_struct("BalanceConstraint");
        out.field("name", &self.constraint_ref.name);
        out.field("impact_type", &self.impact_type);
        out.field("groups", &tracked_groups);
        out.finish()
    }
}