use std::hash::Hash;
use crate::stream::collector::{Accumulator, UniCollector};
use solverforge_core::score::Score;
use solverforge_core::ImpactType;
use super::ComplementedGroupConstraint;
impl<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
ComplementedGroupConstraint<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
where
S: Send + Sync + 'static,
A: Clone + Send + Sync + 'static,
B: Clone + Send + Sync + 'static,
K: Clone + Eq + Hash + Send + Sync,
EA: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
EB: crate::stream::collection_extract::CollectionExtract<S, Item = B>,
KA: Fn(&A) -> Option<K> + Send + Sync,
KB: Fn(&B) -> K + Send + Sync,
C: UniCollector<A> + Send + Sync,
C::Accumulator: Send + Sync,
C::Result: Clone + Send + Sync,
C::Value: Send + Sync,
D: Fn(&B) -> C::Result + Send + Sync,
W: Fn(&C::Result) -> Sc + Send + Sync,
Sc: Score,
{
/// Adds the A entity at `entity_index` to its group and returns the score delta.
///
/// Returns `Sc::zero()` without recording anything when `key_a` yields no
/// key for `entity`. Otherwise the value extracted by the collector is
/// accumulated into the group for that key (creating the accumulator on
/// first use), and the key and value are remembered under `entity_index`.
///
/// When a B entity is registered for the key in `b_by_key`, the return value
/// is the signed weight of the group result after accumulation minus the
/// signed weight before it; the "before" result is the group's current
/// `finish()` or, if the group does not exist yet, `default_fn` applied to
/// that B entity. When no B entity is registered the delta is zero.
///
/// # Panics
///
/// Panics if the B index stored for the key is out of bounds for `entities_b`.
pub(super) fn insert_entity(
    &mut self,
    entities_b: &[B],
    entity_index: usize,
    entity: &A,
) -> Sc {
    let Some(key) = (self.key_a)(entity) else {
        return Sc::zero();
    };
    let value = self.collector.extract(entity);

    // Applies the constraint's impact direction to a raw weight.
    let impact = self.impact_type;
    let signed = |base: Sc| -> Sc {
        match impact {
            ImpactType::Penalty => -base,
            ImpactType::Reward => base,
        }
    };

    // Score currently attributed to this key; `None` means no B entity is
    // registered for it, in which case the insertion is score-neutral.
    let prior = match self.b_by_key.get(&key) {
        Some(&b_idx) => {
            let b = &entities_b[b_idx];
            let result = match self.groups.get(&key) {
                Some(acc) => acc.finish(),
                None => (self.default_fn)(b),
            };
            Some(signed((self.weight_fn)(&result)))
        }
        None => None,
    };

    let acc = self
        .groups
        .entry(key.clone())
        .or_insert_with(|| self.collector.create_accumulator());
    acc.accumulate(&value);

    let delta = match prior {
        Some(old) => signed((self.weight_fn)(&acc.finish())) - old,
        None => Sc::zero(),
    };

    // Remember what was accumulated so it can be retracted later.
    self.entity_groups.insert(entity_index, key);
    self.entity_values.insert(entity_index, value);
    delta
}
/// Computes the score for the B entity at `b_idx` under `key`.
///
/// Returns `Sc::zero()` when `b_idx` is out of bounds for `entities_b`.
/// Otherwise the result passed to `compute_score` is the `finish()` of the
/// group stored for `key`, or `default_fn` applied to the B entity when no
/// such group exists.
fn b_score_for_key(&self, entities_b: &[B], key: &K, b_idx: usize) -> Sc {
    let Some(b) = entities_b.get(b_idx) else {
        return Sc::zero();
    };
    let result = match self.groups.get(key) {
        Some(acc) => acc.finish(),
        None => (self.default_fn)(b),
    };
    self.compute_score(&result)
}
/// Registers the B entity at `b_idx` and returns the score it contributes.
///
/// Returns `Sc::zero()` without touching any state when `b_idx` is out of
/// bounds for `entities_b`. Otherwise the key produced by `key_b` is
/// recorded in both `b_by_key` (key -> index) and `b_index_to_key`
/// (index -> key), and the value of `b_score_for_key` for that key is
/// returned.
pub(super) fn insert_b(&mut self, entities_b: &[B], b_idx: usize) -> Sc {
    let Some(b) = entities_b.get(b_idx) else {
        return Sc::zero();
    };
    let key = (self.key_b)(b);
    self.b_by_key.insert(key.clone(), b_idx);
    // Score before the final insert so `key` can be moved into the index map
    // rather than cloned a second time; `b_score_for_key` itself does not
    // read `b_index_to_key`.
    let score = self.b_score_for_key(entities_b, &key, b_idx);
    self.b_index_to_key.insert(b_idx, key);
    score
}
/// Unregisters the B entity at `b_idx` and returns the score delta.
///
/// Returns `Sc::zero()` when no key is recorded for `b_idx`. Otherwise the
/// key is removed from both `b_index_to_key` and `b_by_key`, and the
/// negation of `b_score_for_key` for that key is returned (which is itself
/// zero when `b_idx` is out of bounds for `entities_b`).
pub(super) fn retract_b(&mut self, entities_b: &[B], b_idx: usize) -> Sc {
    match self.b_index_to_key.remove(&b_idx) {
        None => Sc::zero(),
        Some(key) => {
            self.b_by_key.remove(&key);
            let contribution = self.b_score_for_key(entities_b, &key, b_idx);
            -contribution
        }
    }
}
/// Removes the A entity at `entity_index` from its group and returns the score delta.
///
/// Returns `Sc::zero()` when no key or no value is recorded for
/// `entity_index`, or when the group for the recorded key does not exist.
/// When the group exists, the recorded value is retracted from its
/// accumulator. If a B entity is registered for the key in `b_by_key`, the
/// return value is the signed weight of the group result after retraction
/// minus the signed weight before it; otherwise the delta is zero.
///
/// `_entities_a` and `_entities_b` are not read.
///
/// NOTE(review): the accumulator is left in `groups` even after its last
/// value is retracted, so later scoring for the key uses its `finish()`
/// rather than `default_fn` — confirm this is intended.
pub(super) fn retract_entity(
    &mut self,
    _entities_a: &[A],
    _entities_b: &[B],
    entity_index: usize,
) -> Sc {
    // The value map is only consulted once the key lookup has succeeded.
    let Some((key, value)) = self
        .entity_groups
        .remove(&entity_index)
        .and_then(|key| self.entity_values.remove(&entity_index).map(|value| (key, value)))
    else {
        return Sc::zero();
    };

    let has_b = self.b_by_key.contains_key(&key);
    let Some(acc) = self.groups.get_mut(&key) else {
        return Sc::zero();
    };

    if !has_b {
        // No B entity for this key: keep the accumulator current, no score change.
        acc.retract(&value);
        return Sc::zero();
    }

    // Applies the constraint's impact direction to a raw weight.
    let impact = self.impact_type;
    let signed = |base: Sc| -> Sc {
        match impact {
            ImpactType::Penalty => -base,
            ImpactType::Reward => base,
        }
    };

    let before = signed((self.weight_fn)(&acc.finish()));
    acc.retract(&value);
    let after = signed((self.weight_fn)(&acc.finish()));
    after - before
}
}