use std::collections::HashMap;
use std::hash::Hash;
use std::marker::PhantomData;
use solverforge_core::score::Score;
use solverforge_core::{ConstraintRef, ImpactType};
use crate::api::constraint_set::IncrementalConstraint;
use crate::stream::collector::{Accumulator, UniCollector};
pub struct ComplementedGroupConstraint<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
where
C: UniCollector<A>,
Sc: Score,
{
constraint_ref: ConstraintRef,
impact_type: ImpactType,
extractor_a: EA,
extractor_b: EB,
key_a: KA,
key_b: KB,
collector: C,
default_fn: D,
weight_fn: W,
is_hard: bool,
groups: HashMap<K, C::Accumulator>,
entity_groups: HashMap<usize, K>,
entity_values: HashMap<usize, C::Value>,
b_by_key: HashMap<K, usize>,
_phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> B, fn() -> Sc)>,
}
impl<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
ComplementedGroupConstraint<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
where
S: 'static,
A: Clone + 'static,
B: Clone + 'static,
K: Clone + Eq + Hash,
EA: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
EB: crate::stream::collection_extract::CollectionExtract<S, Item = B>,
KA: Fn(&A) -> Option<K>,
KB: Fn(&B) -> K,
C: UniCollector<A>,
C::Result: Clone,
D: Fn(&B) -> C::Result,
W: Fn(&C::Result) -> Sc,
Sc: Score,
{
#[allow(clippy::too_many_arguments)]
pub fn new(
constraint_ref: ConstraintRef,
impact_type: ImpactType,
extractor_a: EA,
extractor_b: EB,
key_a: KA,
key_b: KB,
collector: C,
default_fn: D,
weight_fn: W,
is_hard: bool,
) -> Self {
Self {
constraint_ref,
impact_type,
extractor_a,
extractor_b,
key_a,
key_b,
collector,
default_fn,
weight_fn,
is_hard,
groups: HashMap::new(),
entity_groups: HashMap::new(),
entity_values: HashMap::new(),
b_by_key: HashMap::new(),
_phantom: PhantomData,
}
}
#[inline]
fn compute_score(&self, result: &C::Result) -> Sc {
let base = (self.weight_fn)(result);
match self.impact_type {
ImpactType::Penalty => -base,
ImpactType::Reward => base,
}
}
fn build_groups(&self, entities_a: &[A]) -> HashMap<K, C::Result> {
let mut accumulators: HashMap<K, C::Accumulator> = HashMap::new();
for a in entities_a {
let Some(key) = (self.key_a)(a) else {
continue;
};
let value = self.collector.extract(a);
accumulators
.entry(key)
.or_insert_with(|| self.collector.create_accumulator())
.accumulate(&value);
}
accumulators
.into_iter()
.map(|(k, acc)| (k, acc.finish()))
.collect()
}
}
impl<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc> IncrementalConstraint<S, Sc>
for ComplementedGroupConstraint<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
where
S: Send + Sync + 'static,
A: Clone + Send + Sync + 'static,
B: Clone + Send + Sync + 'static,
K: Clone + Eq + Hash + Send + Sync,
EA: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
EB: crate::stream::collection_extract::CollectionExtract<S, Item = B>,
KA: Fn(&A) -> Option<K> + Send + Sync,
KB: Fn(&B) -> K + Send + Sync,
C: UniCollector<A> + Send + Sync,
C::Accumulator: Send + Sync,
C::Result: Clone + Send + Sync,
C::Value: Send + Sync,
D: Fn(&B) -> C::Result + Send + Sync,
W: Fn(&C::Result) -> Sc + Send + Sync,
Sc: Score,
{
fn evaluate(&self, solution: &S) -> Sc {
let entities_a = self.extractor_a.extract(solution);
let entities_b = self.extractor_b.extract(solution);
let groups = self.build_groups(entities_a);
let mut total = Sc::zero();
for b in entities_b {
let key = (self.key_b)(b);
let result = groups
.get(&key)
.cloned()
.unwrap_or_else(|| (self.default_fn)(b));
total = total + self.compute_score(&result);
}
total
}
fn match_count(&self, solution: &S) -> usize {
let entities_b = self.extractor_b.extract(solution);
entities_b.len()
}
fn initialize(&mut self, solution: &S) -> Sc {
self.reset();
let entities_a = self.extractor_a.extract(solution);
let entities_b = self.extractor_b.extract(solution);
for (idx, b) in entities_b.iter().enumerate() {
let key = (self.key_b)(b);
self.b_by_key.insert(key, idx);
}
let mut total = Sc::zero();
for b in entities_b {
let default_result = (self.default_fn)(b);
total = total + self.compute_score(&default_result);
}
for (idx, a) in entities_a.iter().enumerate() {
total = total + self.insert_entity(entities_b, idx, a);
}
total
}
fn on_insert(&mut self, solution: &S, entity_index: usize, _descriptor_index: usize) -> Sc {
let entities_a = self.extractor_a.extract(solution);
let entities_b = self.extractor_b.extract(solution);
if entity_index >= entities_a.len() {
return Sc::zero();
}
let entity = &entities_a[entity_index];
self.insert_entity(entities_b, entity_index, entity)
}
fn on_retract(&mut self, solution: &S, entity_index: usize, _descriptor_index: usize) -> Sc {
let entities_a = self.extractor_a.extract(solution);
let entities_b = self.extractor_b.extract(solution);
self.retract_entity(entities_a, entities_b, entity_index)
}
fn reset(&mut self) {
self.groups.clear();
self.entity_groups.clear();
self.entity_values.clear();
self.b_by_key.clear();
}
fn name(&self) -> &str {
&self.constraint_ref.name
}
fn is_hard(&self) -> bool {
self.is_hard
}
fn constraint_ref(&self) -> ConstraintRef {
self.constraint_ref.clone()
}
}
impl<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
ComplementedGroupConstraint<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
where
S: Send + Sync + 'static,
A: Clone + Send + Sync + 'static,
B: Clone + Send + Sync + 'static,
K: Clone + Eq + Hash + Send + Sync,
EA: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
EB: crate::stream::collection_extract::CollectionExtract<S, Item = B>,
KA: Fn(&A) -> Option<K> + Send + Sync,
KB: Fn(&B) -> K + Send + Sync,
C: UniCollector<A> + Send + Sync,
C::Accumulator: Send + Sync,
C::Result: Clone + Send + Sync,
C::Value: Send + Sync,
D: Fn(&B) -> C::Result + Send + Sync,
W: Fn(&C::Result) -> Sc + Send + Sync,
Sc: Score,
{
fn insert_entity(&mut self, entities_b: &[B], entity_index: usize, entity: &A) -> Sc {
let Some(key) = (self.key_a)(entity) else {
return Sc::zero();
};
let value = self.collector.extract(entity);
let impact = self.impact_type;
let b_idx = self.b_by_key.get(&key).copied();
let Some(b_idx) = b_idx else {
let acc = self
.groups
.entry(key.clone())
.or_insert_with(|| self.collector.create_accumulator());
acc.accumulate(&value);
self.entity_groups.insert(entity_index, key);
self.entity_values.insert(entity_index, value);
return Sc::zero();
};
let b = &entities_b[b_idx];
let old_result = self
.groups
.get(&key)
.map(|acc| acc.finish())
.unwrap_or_else(|| (self.default_fn)(b));
let old_base = (self.weight_fn)(&old_result);
let old = match impact {
ImpactType::Penalty => -old_base,
ImpactType::Reward => old_base,
};
let acc = self
.groups
.entry(key.clone())
.or_insert_with(|| self.collector.create_accumulator());
acc.accumulate(&value);
let new_result = acc.finish();
let new_base = (self.weight_fn)(&new_result);
let new_score = match impact {
ImpactType::Penalty => -new_base,
ImpactType::Reward => new_base,
};
self.entity_groups.insert(entity_index, key);
self.entity_values.insert(entity_index, value);
new_score - old
}
fn retract_entity(&mut self, _entities_a: &[A], _entities_b: &[B], entity_index: usize) -> Sc {
let Some(key) = self.entity_groups.remove(&entity_index) else {
return Sc::zero();
};
let Some(value) = self.entity_values.remove(&entity_index) else {
return Sc::zero();
};
let impact = self.impact_type;
let b_idx = self.b_by_key.get(&key).copied();
if b_idx.is_none() {
if let Some(acc) = self.groups.get_mut(&key) {
acc.retract(&value);
}
return Sc::zero();
}
let Some(acc) = self.groups.get_mut(&key) else {
return Sc::zero();
};
let old_result = acc.finish();
let old_base = (self.weight_fn)(&old_result);
let old = match impact {
ImpactType::Penalty => -old_base,
ImpactType::Reward => old_base,
};
acc.retract(&value);
let new_result = acc.finish();
let new_base = (self.weight_fn)(&new_result);
let new_score = match impact {
ImpactType::Penalty => -new_base,
ImpactType::Reward => new_base,
};
new_score - old
}
}
impl<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc> std::fmt::Debug
    for ComplementedGroupConstraint<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
where
    C: UniCollector<A>,
    Sc: Score,
{
    /// Summarises the constraint by name, impact type and number of live
    /// groups. The accumulators themselves are not printed (this impl does not
    /// require them to be `Debug`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let group_count = self.groups.len();
        let mut summary = f.debug_struct("ComplementedGroupConstraint");
        summary.field("name", &self.constraint_ref.name);
        summary.field("impact_type", &self.impact_type);
        summary.field("groups", &group_count);
        summary.finish()
    }
}