use std::collections::HashMap;
use std::hash::Hash;
use std::marker::PhantomData;
use solverforge_core::score::Score;
use solverforge_core::{ConstraintRef, ImpactType};
use crate::api::constraint_set::IncrementalConstraint;
/// Incremental constraint that scores `A` entities against `C` items flattened out of `B`
/// entities.
///
/// Every `B` contributes the slice returned by `flatten`; each item `c` of that slice is indexed
/// under the pair `(key_b(b), c_key_fn(c))`. An `A` is scored by looking up
/// `(key_a(a), a_lookup_fn(a))` in that index and summing `weight(a, c)` over the items accepted
/// by `filter`. The weight is negated when `impact_type` is `ImpactType::Penalty`.
///
/// Type parameters, as used by the impl blocks in this file:
/// - `S`: the solution both extractors read from.
/// - `A`, `B`: the two joined entity types; `C`: the item type flattened out of `B`.
/// - `K`: join key shared by `A` and `B`; `CK`: key shared by `C` items and `A` lookups.
/// - `EA`, `EB`, `KA`, `KB`, `Flatten`, `CKeyFn`, `ALookup`, `F`, `W`: the extractor and closure
///   types stored in the correspondingly named fields.
/// - `Sc`: the score type produced by this constraint.
pub struct FlattenedBiConstraint<
S,
A,
B,
C,
K,
CK,
EA,
EB,
KA,
KB,
Flatten,
CKeyFn,
ALookup,
F,
W,
Sc,
> where
Sc: Score,
{
/// Identity of the constraint; `name()` and `constraint_ref()` are served from it.
constraint_ref: ConstraintRef,
/// Decides the sign of each match: penalties negate the weight, rewards keep it.
impact_type: ImpactType,
/// Extracts the `A` entities from a solution.
extractor_a: EA,
/// Extracts the `B` entities from a solution.
extractor_b: EB,
/// Computes the join key of an `A`.
key_a: KA,
/// Computes the join key of a `B`.
key_b: KB,
/// Returns the slice of `C` items owned by a `B`.
flatten: Flatten,
/// Computes the secondary index key of a `C` item.
c_key_fn: CKeyFn,
/// Computes the secondary key an `A` uses to look up `C` items.
a_lookup_fn: ALookup,
/// Predicate deciding whether an `(A, C)` candidate pair counts as a match.
filter: F,
/// Unsigned contribution of a matching `(A, C)` pair.
weight: W,
/// Value reported by `is_hard()`.
is_hard: bool,
/// Incremental index: `(join key, C key)` -> `(index of the owning B, cloned C item)`.
/// It is cleared by `reset` and rebuilt by `build_c_index`.
c_index: HashMap<(K, CK), Vec<(usize, C)>>,
/// Score currently contributed by each inserted `A`, keyed by its index.
/// Only non-zero scores are stored.
a_scores: HashMap<usize, Sc>,
/// Marks `S`, `A` and `B` as used without storing values of those types.
_phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> B)>,
}
impl<S, A, B, C, K, CK, EA, EB, KA, KB, Flatten, CKeyFn, ALookup, F, W, Sc>
    FlattenedBiConstraint<S, A, B, C, K, CK, EA, EB, KA, KB, Flatten, CKeyFn, ALookup, F, W, Sc>
where
    S: 'static,
    A: Clone + 'static,
    B: Clone + 'static,
    C: Clone + 'static,
    K: Eq + Hash + Clone,
    CK: Eq + Hash + Clone,
    EA: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
    EB: crate::stream::collection_extract::CollectionExtract<S, Item = B>,
    KA: Fn(&A) -> K,
    KB: Fn(&B) -> K,
    Flatten: Fn(&B) -> &[C],
    CKeyFn: Fn(&C) -> CK,
    ALookup: Fn(&A) -> CK,
    F: Fn(&S, &A, &C) -> bool,
    W: Fn(&A, &C) -> Sc,
    Sc: Score,
{
    /// Creates a constraint from its extractors and closures.
    ///
    /// The incremental state (`c_index`, `a_scores`) starts out empty.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        constraint_ref: ConstraintRef,
        impact_type: ImpactType,
        extractor_a: EA,
        extractor_b: EB,
        key_a: KA,
        key_b: KB,
        flatten: Flatten,
        c_key_fn: CKeyFn,
        a_lookup_fn: ALookup,
        filter: F,
        weight: W,
        is_hard: bool,
    ) -> Self {
        Self {
            // Incremental state: nothing indexed, nothing scored yet.
            c_index: HashMap::new(),
            a_scores: HashMap::new(),
            _phantom: PhantomData,
            // Caller-supplied configuration.
            constraint_ref,
            impact_type,
            is_hard,
            extractor_a,
            extractor_b,
            key_a,
            key_b,
            flatten,
            c_key_fn,
            a_lookup_fn,
            filter,
            weight,
        }
    }

    /// Returns the signed contribution of one `(a, c)` pair: the weight as-is for a reward,
    /// negated for a penalty.
    #[inline]
    fn compute_score(&self, a: &A, c: &C) -> Sc {
        let magnitude = (self.weight)(a, c);
        match self.impact_type {
            ImpactType::Reward => magnitude,
            ImpactType::Penalty => -magnitude,
        }
    }

    /// Rebuilds `c_index` from scratch for the given `B` entities.
    ///
    /// Every item of `flatten(b)` is cloned and stored, together with the position of its
    /// owning `B`, under `(key_b(b), c_key_fn(item))`.
    fn build_c_index(&mut self, entities_b: &[B]) {
        self.c_index.clear();
        for (position, owner) in entities_b.iter().enumerate() {
            let join_key = (self.key_b)(owner);
            for item in (self.flatten)(owner).iter() {
                let item_key = (self.c_key_fn)(item);
                let bucket = self
                    .c_index
                    .entry((join_key.clone(), item_key))
                    .or_default();
                bucket.push((position, item.clone()));
            }
        }
    }

    /// Sums the signed scores of every indexed `C` that shares both keys with `a` and passes
    /// the filter. Returns zero when no bucket exists for `a`'s keys.
    fn compute_a_score(&self, solution: &S, a: &A) -> Sc {
        let probe = ((self.key_a)(a), (self.a_lookup_fn)(a));
        if let Some(bucket) = self.c_index.get(&probe) {
            bucket
                .iter()
                .filter(|(_, c)| (self.filter)(solution, a, c))
                .map(|(_, c)| self.compute_score(a, c))
                .fold(Sc::zero(), |running, part| running + part)
        } else {
            Sc::zero()
        }
    }

    /// Scores the `A` at `a_idx`, remembers the result when it is non-zero, and returns it as
    /// the delta to apply. An out-of-range index contributes zero and records nothing.
    fn insert_a(&mut self, solution: &S, entities_a: &[A], a_idx: usize) -> Sc {
        match entities_a.get(a_idx) {
            None => Sc::zero(),
            Some(a) => {
                let delta = self.compute_a_score(solution, a);
                // Zero scores are not cached, so `retract_a` treats them as "nothing to undo".
                if delta != Sc::zero() {
                    self.a_scores.insert(a_idx, delta);
                }
                delta
            }
        }
    }

    /// Forgets the cached score of the `A` at `a_idx` and returns its negation, or zero when
    /// nothing was cached for that index.
    fn retract_a(&mut self, a_idx: usize) -> Sc {
        self.a_scores
            .remove(&a_idx)
            .map_or_else(Sc::zero, |cached| -cached)
    }
}
/// Builds a throw-away lookup table over every `C` item flattened out of `entities_b`.
///
/// Items are grouped under `(key_b(b), c_key_fn(c))` in encounter order. The table stores
/// references into `entities_b` rather than clones, so building it never copies a `C`.
fn index_flattened_by_key<'s, B, C, K, CK>(
    entities_b: &'s [B],
    key_b: impl Fn(&B) -> K,
    flatten: impl Fn(&B) -> &[C],
    c_key_fn: impl Fn(&C) -> CK,
) -> HashMap<(K, CK), Vec<&'s C>>
where
    K: Eq + Hash + Clone,
    CK: Eq + Hash,
{
    let mut index: HashMap<(K, CK), Vec<&'s C>> = HashMap::new();
    for b in entities_b {
        let join_key = key_b(b);
        for c in flatten(b) {
            let c_key = c_key_fn(c);
            index.entry((join_key.clone(), c_key)).or_default().push(c);
        }
    }
    index
}

impl<S, A, B, C, K, CK, EA, EB, KA, KB, Flatten, CKeyFn, ALookup, F, W, Sc>
    IncrementalConstraint<S, Sc>
    for FlattenedBiConstraint<S, A, B, C, K, CK, EA, EB, KA, KB, Flatten, CKeyFn, ALookup, F, W, Sc>
where
    S: Send + Sync + 'static,
    A: Clone + Send + Sync + 'static,
    B: Clone + Send + Sync + 'static,
    C: Clone + Send + Sync + 'static,
    K: Eq + Hash + Clone + Send + Sync,
    CK: Eq + Hash + Clone + Send + Sync,
    EA: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
    EB: crate::stream::collection_extract::CollectionExtract<S, Item = B>,
    KA: Fn(&A) -> K + Send + Sync,
    KB: Fn(&B) -> K + Send + Sync,
    Flatten: Fn(&B) -> &[C] + Send + Sync,
    CKeyFn: Fn(&C) -> CK + Send + Sync,
    ALookup: Fn(&A) -> CK + Send + Sync,
    F: Fn(&S, &A, &C) -> bool + Send + Sync,
    W: Fn(&A, &C) -> Sc + Send + Sync,
    Sc: Score,
{
    /// Computes the full score of `solution` from scratch.
    ///
    /// Uses a temporary index and neither reads nor modifies the incremental state.
    fn evaluate(&self, solution: &S) -> Sc {
        let entities_a = self.extractor_a.extract(solution);
        let entities_b = self.extractor_b.extract(solution);
        let index =
            index_flattened_by_key(entities_b, &self.key_b, &self.flatten, &self.c_key_fn);

        let mut total = Sc::zero();
        for a in entities_a {
            let probe = ((self.key_a)(a), (self.a_lookup_fn)(a));
            if let Some(candidates) = index.get(&probe) {
                for &c in candidates {
                    if (self.filter)(solution, a, c) {
                        total = total + self.compute_score(a, c);
                    }
                }
            }
        }
        total
    }

    /// Counts the `(A, C)` pairs that share both keys and pass the filter.
    ///
    /// Like `evaluate`, this works on a temporary index and leaves the incremental state alone.
    fn match_count(&self, solution: &S) -> usize {
        let entities_a = self.extractor_a.extract(solution);
        let entities_b = self.extractor_b.extract(solution);
        let index =
            index_flattened_by_key(entities_b, &self.key_b, &self.flatten, &self.c_key_fn);

        let mut count = 0;
        for a in entities_a {
            let probe = ((self.key_a)(a), (self.a_lookup_fn)(a));
            if let Some(candidates) = index.get(&probe) {
                count += candidates
                    .iter()
                    .filter(|&&c| (self.filter)(solution, a, c))
                    .count();
            }
        }
        count
    }

    /// Discards any previous incremental state, indexes every flattened `C` of the solution's
    /// `B` entities, inserts every `A`, and returns the resulting total score.
    fn initialize(&mut self, solution: &S) -> Sc {
        self.reset();
        let entities_a = self.extractor_a.extract(solution);
        let entities_b = self.extractor_b.extract(solution);
        self.build_c_index(entities_b);

        let mut total = Sc::zero();
        for a_idx in 0..entities_a.len() {
            total = total + self.insert_a(solution, entities_a, a_idx);
        }
        total
    }

    /// Scores the `A` at `entity_index` against the index built by `initialize` and returns
    /// the score delta.
    ///
    /// The descriptor index is ignored: the index is always interpreted as a position in the
    /// `A` collection, and the `C` index is not updated here.
    fn on_insert(&mut self, solution: &S, entity_index: usize, _descriptor_index: usize) -> Sc {
        let entities_a = self.extractor_a.extract(solution);
        self.insert_a(solution, entities_a, entity_index)
    }

    /// Removes the cached contribution of the `A` at `entity_index` and returns the score
    /// delta (the negated cached score, or zero when none was cached).
    ///
    /// Neither the solution nor the descriptor index is consulted.
    fn on_retract(&mut self, _solution: &S, entity_index: usize, _descriptor_index: usize) -> Sc {
        self.retract_a(entity_index)
    }

    /// Clears the `C` index and all cached per-`A` scores.
    fn reset(&mut self) {
        self.c_index.clear();
        self.a_scores.clear();
    }

    /// Name taken from the constraint reference.
    fn name(&self) -> &str {
        &self.constraint_ref.name
    }

    /// Returns the hardness flag supplied at construction.
    fn is_hard(&self) -> bool {
        self.is_hard
    }

    /// Returns a clone of the constraint reference.
    fn constraint_ref(&self) -> ConstraintRef {
        self.constraint_ref.clone()
    }
}
impl<S, A, B, C, K, CK, EA, EB, KA, KB, Flatten, CKeyFn, ALookup, F, W, Sc: Score> std::fmt::Debug
    for FlattenedBiConstraint<S, A, B, C, K, CK, EA, EB, KA, KB, Flatten, CKeyFn, ALookup, F, W, Sc>
{
    /// Formats a compact summary: the constraint name, its impact type and the number of
    /// distinct key pairs currently held in the `C` index. The stored closures and extractors
    /// are not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let indexed_key_pairs = self.c_index.len();
        let mut summary = f.debug_struct("FlattenedBiConstraint");
        summary.field("name", &self.constraint_ref.name);
        summary.field("impact_type", &self.impact_type);
        summary.field("c_index_size", &indexed_key_pairs);
        summary.finish()
    }
}