use std::collections::HashSet;
use std::hash::Hash;
use std::marker::PhantomData;
use solverforge_core::score::Score;
use solverforge_core::{ConstraintRef, ImpactType};
use crate::api::constraint_set::IncrementalConstraint;
/// Selects which outcome of the B-key lookup makes an A entity match.
///
/// The lookup asks whether the key computed for an A entity is among the keys
/// computed for the B entities of the same solution.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ExistenceMode {
    /// Match an A entity when at least one B entity shares its key.
    Exists,
    /// Match an A entity when no B entity shares its key.
    NotExists,
}
/// Uni-constraint over A entities whose match condition depends on whether some
/// B entity in the same solution shares the A entity's key.
///
/// Every field is configuration supplied at construction time; the struct holds no
/// per-solution cache.
pub struct IfExistsUniConstraint<S, A, B, K, EA, EB, KA, KB, FA, W, Sc>
where
    Sc: Score,
{
    /// Identity of the constraint (its `name` is used for display/reporting).
    constraint_ref: ConstraintRef,
    /// Whether a matched entity's weight is subtracted (penalty) or added (reward).
    impact_type: ImpactType,
    /// Whether a match requires the key to be present among B keys or absent.
    mode: ExistenceMode,
    /// Extracts the A entities from the solution.
    extractor_a: EA,
    /// Produces the B entities from the solution (an owned `Vec` per call).
    extractor_b: EB,
    /// Computes the join key of an A entity.
    key_a: KA,
    /// Computes the join key of a B entity.
    key_b: KB,
    /// Pre-filter on A entities, applied before the existence test.
    filter_a: FA,
    /// Base (unsigned) score of a matched A entity.
    weight: W,
    /// Whether this constraint is hard.
    is_hard: bool,
    /// Marks the type parameters as used without storing any values of them; the
    /// `fn() -> ...` form keeps the struct's `Send`/`Sync` independent of them.
    _phantom: PhantomData<fn() -> (S, A, B, K, Sc)>,
}
impl<S, A, B, K, EA, EB, KA, KB, FA, W, Sc>
    IfExistsUniConstraint<S, A, B, K, EA, EB, KA, KB, FA, W, Sc>
where
    S: 'static,
    A: Clone + 'static,
    B: Clone + 'static,
    K: Eq + Hash + Clone,
    EA: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
    EB: Fn(&S) -> Vec<B>,
    KA: Fn(&A) -> K,
    KB: Fn(&B) -> K,
    FA: Fn(&S, &A) -> bool,
    W: Fn(&A) -> Sc,
    Sc: Score,
{
    /// Assembles a constraint from its parts; no validation is performed.
    // Every part is a distinct, independently-typed closure or setting, so a long
    // argument list is inherent to this constructor.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        constraint_ref: ConstraintRef,
        impact_type: ImpactType,
        mode: ExistenceMode,
        extractor_a: EA,
        extractor_b: EB,
        key_a: KA,
        key_b: KB,
        filter_a: FA,
        weight: W,
        is_hard: bool,
    ) -> Self {
        Self {
            constraint_ref,
            impact_type,
            mode,
            extractor_a,
            extractor_b,
            key_a,
            key_b,
            filter_a,
            weight,
            is_hard,
            _phantom: PhantomData,
        }
    }

    /// Weight of `a`, signed by the impact type: kept as-is for a reward,
    /// negated for a penalty.
    #[inline]
    fn compute_score(&self, a: &A) -> Sc {
        let magnitude = (self.weight)(a);
        match self.impact_type {
            ImpactType::Reward => magnitude,
            ImpactType::Penalty => -magnitude,
        }
    }

    /// Collects the key of every B entity in `solution` into a set for O(1) probes.
    fn build_b_keys(&self, solution: &S) -> HashSet<K> {
        let entities_b: Vec<B> = (self.extractor_b)(solution);
        let mut keys = HashSet::with_capacity(entities_b.len());
        for b in &entities_b {
            keys.insert((self.key_b)(b));
        }
        keys
    }

    /// Whether `a` satisfies the configured existence mode against `b_keys`.
    fn matches_existence(&self, a: &A, b_keys: &HashSet<K>) -> bool {
        let found = b_keys.contains(&(self.key_a)(a));
        // Exists: match when found; NotExists: match when not found.
        found == matches!(self.mode, ExistenceMode::Exists)
    }
}
impl<S, A, B, K, EA, EB, KA, KB, FA, W, Sc> IncrementalConstraint<S, Sc>
for IfExistsUniConstraint<S, A, B, K, EA, EB, KA, KB, FA, W, Sc>
where
S: Send + Sync + 'static,
A: Clone + Send + Sync + 'static,
B: Clone + Send + Sync + 'static,
K: Eq + Hash + Clone + Send + Sync,
EA: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
EB: Fn(&S) -> Vec<B> + Send + Sync,
KA: Fn(&A) -> K + Send + Sync,
KB: Fn(&B) -> K + Send + Sync,
FA: Fn(&S, &A) -> bool + Send + Sync,
W: Fn(&A) -> Sc + Send + Sync,
Sc: Score,
{
fn evaluate(&self, solution: &S) -> Sc {
let entities_a = self.extractor_a.extract(solution);
let b_keys = self.build_b_keys(solution);
let mut total = Sc::zero();
for a in entities_a {
if (self.filter_a)(solution, a) && self.matches_existence(a, &b_keys) {
total = total + self.compute_score(a);
}
}
total
}
fn match_count(&self, solution: &S) -> usize {
let entities_a = self.extractor_a.extract(solution);
let b_keys = self.build_b_keys(solution);
entities_a
.iter()
.filter(|a| (self.filter_a)(solution, a) && self.matches_existence(a, &b_keys))
.count()
}
fn initialize(&mut self, solution: &S) -> Sc {
self.evaluate(solution)
}
fn on_insert(&mut self, solution: &S, entity_index: usize, _descriptor_index: usize) -> Sc {
let entities_a = self.extractor_a.extract(solution);
if entity_index >= entities_a.len() {
return Sc::zero();
}
let a = &entities_a[entity_index];
if !(self.filter_a)(solution, a) {
return Sc::zero();
}
let b_keys = self.build_b_keys(solution);
if self.matches_existence(a, &b_keys) {
self.compute_score(a)
} else {
Sc::zero()
}
}
fn on_retract(&mut self, solution: &S, entity_index: usize, _descriptor_index: usize) -> Sc {
let entities_a = self.extractor_a.extract(solution);
if entity_index >= entities_a.len() {
return Sc::zero();
}
let a = &entities_a[entity_index];
if !(self.filter_a)(solution, a) {
return Sc::zero();
}
let b_keys = self.build_b_keys(solution);
if self.matches_existence(a, &b_keys) {
-self.compute_score(a)
} else {
Sc::zero()
}
}
fn reset(&mut self) {
}
fn name(&self) -> &str {
&self.constraint_ref.name
}
fn is_hard(&self) -> bool {
self.is_hard
}
fn constraint_ref(&self) -> ConstraintRef {
self.constraint_ref.clone()
}
}
impl<S, A, B, K, EA, EB, KA, KB, FA, W, Sc: Score> std::fmt::Debug
    for IfExistsUniConstraint<S, A, B, K, EA, EB, KA, KB, FA, W, Sc>
{
    /// Shows only the constraint's name, impact type and existence mode; the stored
    /// closures and extractors are not printable and are omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            constraint_ref,
            impact_type,
            mode,
            ..
        } = self;
        let mut out = f.debug_struct("IfExistsUniConstraint");
        out.field("name", &constraint_ref.name);
        out.field("impact_type", impact_type);
        out.field("mode", mode);
        out.finish()
    }
}