use std::collections::HashMap;
use std::hash::Hash;
use std::marker::PhantomData;
use solverforge_core::score::Score;
use solverforge_core::{ConstraintRef, ImpactType};
use crate::stream::collection_extract::ChangeSource;
use crate::stream::collector::{Accumulator, UniCollector};
/// Constraint that groups `A` entities by key and folds each group with a
/// [`UniCollector`], then scores the finished per-group results.
///
/// `B` entities are mapped to the same key space via `key_b`, and `default_fn`
/// produces a result from a `B`. Presumably this supplies a result for keys that
/// have no `A` entities (the "complement"); that logic is not shown here, so
/// confirm against the scoring code.
pub struct ComplementedGroupConstraint<
    S,
    A,
    B,
    K,
    EA,
    EB,
    KA,
    KB,
    C: UniCollector<A>,
    D,
    W,
    Sc: Score,
> {
    /// Identifies this constraint.
    pub(super) constraint_ref: ConstraintRef,
    /// Whether the weighted result is penalised (negated) or rewarded.
    pub(super) impact_type: ImpactType,
    /// Extracts the `A` items from the solution.
    pub(super) extractor_a: EA,
    /// Extracts the `B` items from the solution.
    pub(super) extractor_b: EB,
    /// Group key of an `A`; `None` excludes that entity from every group.
    pub(super) key_a: KA,
    /// Group key of a `B`.
    pub(super) key_b: KB,
    /// Extracts per-entity values and creates the per-group accumulators.
    pub(super) collector: C,
    /// Produces a collector result from a `B` (presumably for keys with no `A`).
    pub(super) default_fn: D,
    /// Maps a finished collector result to a score before the impact sign is applied.
    pub(super) weight_fn: W,
    /// Hard/soft flag. Not read in this file; presumably marks a hard constraint.
    pub(super) is_hard: bool,
    /// Change source reported by `extractor_a`, cached at construction.
    pub(super) a_source: ChangeSource,
    /// Change source reported by `extractor_b`, cached at construction.
    pub(super) b_source: ChangeSource,
    /// Live accumulator per group key (incremental state; starts empty).
    pub(super) groups: HashMap<K, C::Accumulator>,
    /// Index -> group key, presumably for the `A` entities; verify against callers.
    pub(super) entity_groups: HashMap<usize, K>,
    /// Index -> extracted collector value, presumably per `A` entity.
    pub(super) entity_values: HashMap<usize, C::Value>,
    /// Group key -> index, presumably of the `B` entity owning that key.
    pub(super) b_by_key: HashMap<K, usize>,
    /// Index -> group key, presumably the reverse of `b_by_key`.
    pub(super) b_index_to_key: HashMap<usize, K>,
    /// Carries the otherwise-unused type parameters. `fn() -> T` keeps the
    /// struct covariant and does not make it `!Send`/`!Sync` because of `T`.
    pub(super) _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> B, fn() -> Sc)>,
}
impl<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
    ComplementedGroupConstraint<S, A, B, K, EA, EB, KA, KB, C, D, W, Sc>
where
    S: 'static,
    A: Clone + 'static,
    B: Clone + 'static,
    K: Clone + Eq + Hash,
    EA: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
    EB: crate::stream::collection_extract::CollectionExtract<S, Item = B>,
    KA: Fn(&A) -> Option<K>,
    KB: Fn(&B) -> K,
    C: UniCollector<A>,
    C::Result: Clone,
    D: Fn(&B) -> C::Result,
    W: Fn(&C::Result) -> Sc,
    Sc: Score,
{
    /// Builds the constraint from its configuration, with all incremental
    /// state empty.
    ///
    /// Each extractor's `change_source()` is queried exactly once, here, and
    /// cached in `a_source` / `b_source`.
    // Every argument maps one-to-one onto a configuration field; bundling them
    // into an options type would only add indirection, hence the allow.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        constraint_ref: ConstraintRef,
        impact_type: ImpactType,
        extractor_a: EA,
        extractor_b: EB,
        key_a: KA,
        key_b: KB,
        collector: C,
        default_fn: D,
        weight_fn: W,
        is_hard: bool,
    ) -> Self {
        Self {
            // Struct-expression fields are evaluated in the order written, so
            // both extractors are borrowed here before they are moved below.
            a_source: extractor_a.change_source(),
            b_source: extractor_b.change_source(),
            constraint_ref,
            impact_type,
            extractor_a,
            extractor_b,
            key_a,
            key_b,
            collector,
            default_fn,
            weight_fn,
            is_hard,
            groups: HashMap::new(),
            entity_groups: HashMap::new(),
            entity_values: HashMap::new(),
            b_by_key: HashMap::new(),
            b_index_to_key: HashMap::new(),
            _phantom: PhantomData,
        }
    }

    /// Converts a finished group result into this constraint's score
    /// contribution.
    ///
    /// Returns `weight_fn(result)` unchanged for [`ImpactType::Reward`] and
    /// negated for [`ImpactType::Penalty`].
    #[inline]
    pub(super) fn compute_score(&self, result: &C::Result) -> Sc {
        let magnitude = (self.weight_fn)(result);
        match self.impact_type {
            ImpactType::Reward => magnitude,
            ImpactType::Penalty => -magnitude,
        }
    }

    /// Computes every group's collector result from scratch over `entities_a`.
    ///
    /// Reads only the configuration (`key_a`, `collector`); the incremental
    /// maps on `self` are neither read nor written.
    ///
    /// - Entities for which `key_a` returns `None` are skipped.
    /// - Each remaining entity's extracted value is fed into its key's
    ///   accumulator. The accumulator is created the first time the key is seen.
    /// - Only keys reached by at least one `A` appear in the returned map.
    pub(super) fn build_groups(&self, entities_a: &[A]) -> HashMap<K, C::Result> {
        let per_key = entities_a
            .iter()
            .filter_map(|entity| (self.key_a)(entity).map(|key| (key, entity)))
            .fold(
                HashMap::<K, C::Accumulator>::new(),
                |mut open, (key, entity)| {
                    let extracted = self.collector.extract(entity);
                    open.entry(key)
                        .or_insert_with(|| self.collector.create_accumulator())
                        .accumulate(&extracted);
                    open
                },
            );

        let mut finished = HashMap::new();
        for (key, accumulator) in per_key {
            finished.insert(key, accumulator.finish());
        }
        finished
    }
}