use std::collections::HashMap;
use std::hash::Hash;
use std::marker::PhantomData;
use solverforge_core::score::Score;
use solverforge_core::{ConstraintRef, ImpactType};
use crate::api::constraint_set::IncrementalConstraint;
use crate::stream::collector::{Accumulator, UniCollector};
use crate::stream::filter::UniFilter;
/// Live bookkeeping for a single group key.
struct GroupState<Acc> {
/// Collector accumulator that has absorbed the value of every entity currently in the group.
accumulator: Acc,
/// Number of entities currently contributing to `accumulator`; incremented on insert,
/// decremented on retract, and the group is removed once it reaches zero.
count: usize,
}
/// Constraint that groups the items extracted from a solution by key, folds each group
/// through a collector, and scores every group with a weight function.
///
/// Besides the configuration, the struct carries the incremental state (`groups`,
/// `entity_groups`, `entity_values`) that lets a single insert or retract be applied
/// without re-reading the whole solution.
pub struct GroupedUniConstraint<S, A, K, E, Fi, KF, C, W, Sc>
where
C: UniCollector<A>,
Sc: Score,
{
/// Identity of the constraint; its `name` is what `name()` returns.
constraint_ref: ConstraintRef,
/// Direction of the score: `Penalty` negates the weight, `Reward` uses it as is.
impact_type: ImpactType,
/// Produces the items of type `A` from a solution.
extractor: E,
/// Items for which `test` returns false are ignored.
filter: Fi,
/// Maps an item to the key of the group it belongs to.
key_fn: KF,
/// Extracts the per-item value and creates the per-group accumulators.
collector: C,
/// Maps a group key and its finished collector result to an unsigned score.
weight_fn: W,
/// Value reported by `is_hard()`.
is_hard: bool,
/// Cached result of `extractor.change_source()`, consulted by `on_insert`/`on_retract`
/// before they touch any state.
change_source: crate::stream::collection_extract::ChangeSource,
/// Incremental state per group key.
groups: HashMap<K, GroupState<C::Accumulator>>,
/// Entity index -> key of the group the entity was inserted into.
entity_groups: HashMap<usize, K>,
/// Entity index -> value that was accumulated for it, kept so a retract can undo
/// exactly what the insert added.
entity_values: HashMap<usize, C::Value>,
/// Ties the otherwise unused type parameters `S`, `A` and `Sc` to the struct; the
/// `fn() -> T` form means no value of those types is owned.
_phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> Sc)>,
}
impl<S, A, K, E, Fi, KF, C, W, Sc> GroupedUniConstraint<S, A, K, E, Fi, KF, C, W, Sc>
where
    S: Send + Sync + 'static,
    A: Clone + Send + Sync + 'static,
    K: Clone + Eq + Hash + Send + Sync + 'static,
    E: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
    Fi: UniFilter<S, A>,
    KF: Fn(&A) -> K + Send + Sync,
    C: UniCollector<A> + Send + Sync + 'static,
    C::Accumulator: Send + Sync,
    C::Result: Send + Sync,
    W: Fn(&K, &C::Result) -> Sc + Send + Sync,
    Sc: Score + 'static,
{
    /// Builds a constraint whose incremental state starts out empty.
    ///
    /// The change source is read from `extractor` once, here, and cached in the struct.
    /// All other arguments are stored unchanged.
    // Each argument is a separate pluggable component of the constraint, hence the
    // long parameter list.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        constraint_ref: ConstraintRef,
        impact_type: ImpactType,
        extractor: E,
        filter: Fi,
        key_fn: KF,
        collector: C,
        weight_fn: W,
        is_hard: bool,
    ) -> Self {
        Self {
            // Listed first on purpose: fields are evaluated in written order and this
            // borrow has to happen before `extractor` is moved into the struct.
            change_source: extractor.change_source(),
            constraint_ref,
            impact_type,
            extractor,
            filter,
            key_fn,
            collector,
            weight_fn,
            is_hard,
            groups: HashMap::default(),
            entity_groups: HashMap::default(),
            entity_values: HashMap::default(),
            _phantom: PhantomData,
        }
    }

    /// Scores one group: the weight function's result, negated when the constraint is
    /// a penalty and used unchanged when it is a reward.
    fn compute_score(&self, key: &K, result: &C::Result) -> Sc {
        let weight = (self.weight_fn)(key, result);
        match self.impact_type {
            ImpactType::Reward => weight,
            ImpactType::Penalty => -weight,
        }
    }
}
impl<S, A, K, E, Fi, KF, C, W, Sc> IncrementalConstraint<S, Sc>
    for GroupedUniConstraint<S, A, K, E, Fi, KF, C, W, Sc>
where
    S: Send + Sync + 'static,
    A: Clone + Send + Sync + 'static,
    K: Clone + Eq + Hash + Send + Sync + 'static,
    E: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
    Fi: UniFilter<S, A>,
    KF: Fn(&A) -> K + Send + Sync,
    C: UniCollector<A> + Send + Sync + 'static,
    C::Accumulator: Send + Sync,
    C::Result: Send + Sync,
    C::Value: Send + Sync,
    W: Fn(&K, &C::Result) -> Sc + Send + Sync,
    Sc: Score + 'static,
{
    /// Computes the constraint's total score from scratch.
    ///
    /// Every extracted entity that passes the filter is accumulated into a temporary
    /// per-key accumulator; the signed scores of all groups are then summed. The
    /// incremental state stored on `self` is neither read nor modified.
    fn evaluate(&self, solution: &S) -> Sc {
        let entities = self.extractor.extract(solution);
        let mut groups: HashMap<K, C::Accumulator> = HashMap::new();
        for entity in entities {
            if !self.filter.test(solution, entity) {
                continue;
            }
            let key = (self.key_fn)(entity);
            let value = self.collector.extract(entity);
            groups
                .entry(key)
                .or_insert_with(|| self.collector.create_accumulator())
                .accumulate(&value);
        }
        groups.iter().fold(Sc::zero(), |total, (key, acc)| {
            total + self.compute_score(key, &acc.finish())
        })
    }

    /// Returns the number of distinct group keys among the entities that pass the
    /// filter.
    fn match_count(&self, solution: &S) -> usize {
        use std::collections::HashSet;

        let entities = self.extractor.extract(solution);
        let mut keys: HashSet<K> = HashSet::new();
        for entity in entities {
            if self.filter.test(solution, entity) {
                keys.insert((self.key_fn)(entity));
            }
        }
        keys.len()
    }

    /// Rebuilds the incremental state from `solution` and returns the resulting total.
    ///
    /// Any previous state is discarded first; each entity that passes the filter is
    /// then inserted under its position in the extracted slice, and the per-insert
    /// deltas are summed.
    fn initialize(&mut self, solution: &S) -> Sc {
        self.reset();
        let entities = self.extractor.extract(solution);
        let mut total = Sc::zero();
        for (idx, entity) in entities.iter().enumerate() {
            if self.filter.test(solution, entity) {
                total = total + self.insert_entity(entities, idx, entity);
            }
        }
        total
    }

    /// Applies the insertion of the entity at `entity_index` and returns the score
    /// delta.
    ///
    /// Returns zero without touching state when the change source rejects
    /// `descriptor_index`, when `entity_index` is out of range, or when the entity
    /// fails the filter.
    fn on_insert(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
        // NOTE(review): presumably `assert_localizes` reports whether a change to this
        // descriptor can affect the extracted collection — confirm against `ChangeSource`.
        if !self
            .change_source
            .assert_localizes(descriptor_index, &self.constraint_ref.name)
        {
            return Sc::zero();
        }
        let entities = self.extractor.extract(solution);
        let Some(entity) = entities.get(entity_index) else {
            return Sc::zero();
        };
        if !self.filter.test(solution, entity) {
            return Sc::zero();
        }
        self.insert_entity(entities, entity_index, entity)
    }

    /// Applies the retraction of the entity at `entity_index` and returns the score
    /// delta.
    ///
    /// Returns zero without touching state when the change source rejects
    /// `descriptor_index`. The filter is not re-evaluated here: an entity that was
    /// never inserted is simply not tracked and yields zero.
    fn on_retract(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
        if !self
            .change_source
            .assert_localizes(descriptor_index, &self.constraint_ref.name)
        {
            return Sc::zero();
        }
        let entities = self.extractor.extract(solution);
        self.retract_entity(entities, entity_index)
    }

    /// Drops all incremental state (groups and per-entity tracking).
    fn reset(&mut self) {
        self.groups.clear();
        self.entity_groups.clear();
        self.entity_values.clear();
    }

    /// Name taken from the constraint reference.
    fn name(&self) -> &str {
        &self.constraint_ref.name
    }

    /// Whether the constraint was constructed as hard.
    fn is_hard(&self) -> bool {
        self.is_hard
    }

    /// The constraint reference supplied at construction.
    fn constraint_ref(&self) -> &ConstraintRef {
        &self.constraint_ref
    }
}
impl<S, A, K, E, Fi, KF, C, W, Sc> GroupedUniConstraint<S, A, K, E, Fi, KF, C, W, Sc>
where
S: Send + Sync + 'static,
A: Clone + Send + Sync + 'static,
K: Clone + Eq + Hash + Send + Sync + 'static,
E: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
Fi: UniFilter<S, A>,
KF: Fn(&A) -> K + Send + Sync,
C: UniCollector<A> + Send + Sync + 'static,
C::Accumulator: Send + Sync,
C::Result: Send + Sync,
C::Value: Send + Sync,
W: Fn(&K, &C::Result) -> Sc + Send + Sync,
Sc: Score + 'static,
{
fn insert_entity(&mut self, _entities: &[A], entity_index: usize, entity: &A) -> Sc {
let key = (self.key_fn)(entity);
let entity_key = (self.key_fn)(entity);
let value = self.collector.extract(entity);
let impact = self.impact_type;
let weight_fn = &self.weight_fn;
let (old, new_score) = match self.groups.entry(key) {
std::collections::hash_map::Entry::Occupied(mut entry) => {
let old_base = weight_fn(entry.key(), &entry.get().accumulator.finish());
let old = match impact {
ImpactType::Penalty => -old_base,
ImpactType::Reward => old_base,
};
let group = entry.get_mut();
group.accumulator.accumulate(&value);
group.count += 1;
let new_base = weight_fn(entry.key(), &entry.get().accumulator.finish());
let new_score = match impact {
ImpactType::Penalty => -new_base,
ImpactType::Reward => new_base,
};
(old, new_score)
}
std::collections::hash_map::Entry::Vacant(entry) => {
let mut entry = entry.insert_entry(GroupState {
accumulator: self.collector.create_accumulator(),
count: 0,
});
let group = entry.get_mut();
group.accumulator.accumulate(&value);
group.count += 1;
let new_base = weight_fn(entry.key(), &entry.get().accumulator.finish());
let new_score = match impact {
ImpactType::Penalty => -new_base,
ImpactType::Reward => new_base,
};
(Sc::zero(), new_score)
}
};
self.entity_groups.insert(entity_index, entity_key);
self.entity_values.insert(entity_index, value);
new_score - old
}
fn retract_entity(&mut self, _entities: &[A], entity_index: usize) -> Sc {
let Some(key) = self.entity_groups.remove(&entity_index) else {
return Sc::zero();
};
let Some(value) = self.entity_values.remove(&entity_index) else {
return Sc::zero();
};
let impact = self.impact_type;
let weight_fn = &self.weight_fn;
let std::collections::hash_map::Entry::Occupied(mut entry) = self.groups.entry(key) else {
return Sc::zero();
};
let old_base = weight_fn(entry.key(), &entry.get().accumulator.finish());
let old = match impact {
ImpactType::Penalty => -old_base,
ImpactType::Reward => old_base,
};
let group = entry.get_mut();
group.accumulator.retract(&value);
group.count = group.count.saturating_sub(1);
let is_empty = group.count == 0;
let new_score = if is_empty {
entry.remove();
Sc::zero()
} else {
let new_base = weight_fn(entry.key(), &entry.get().accumulator.finish());
match impact {
ImpactType::Penalty => -new_base,
ImpactType::Reward => new_base,
}
};
new_score - old
}
}
impl<S, A, K, E, Fi, KF, C, W, Sc> std::fmt::Debug
    for GroupedUniConstraint<S, A, K, E, Fi, KF, C, W, Sc>
where
    C: UniCollector<A>,
    Sc: Score,
{
    /// Formats a compact summary: the constraint name, its impact type, and the number
    /// of groups currently held in the incremental state. The closures and the
    /// per-group contents are not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let group_count = self.groups.len();
        let mut summary = f.debug_struct("GroupedUniConstraint");
        summary.field("name", &self.constraint_ref.name);
        summary.field("impact_type", &self.impact_type);
        summary.field("groups", &group_count);
        summary.finish()
    }
}