use std::fmt::Debug;
use std::marker::PhantomData;
use solverforge_core::score::Score;
use solverforge_core::{ConstraintRef, ImpactType};
use crate::api::analysis::{ConstraintJustification, DetailedConstraintMatch, EntityRef};
use crate::api::constraint_set::IncrementalConstraint;
/// Constraint evaluated over a single collection of entities extracted from a
/// solution.
///
/// Every extracted entity accepted by the filter contributes its weight to the
/// score, negated for penalties and unchanged for rewards.
///
/// Type parameters:
/// - `S`: solution type handed to the extractor and the filter.
/// - `A`: entity type produced by the extractor.
/// - `E`: extractor yielding the entities of a solution.
/// - `F`: filter predicate over `(solution, entity)`.
/// - `W`: weight function mapping an entity to a score.
/// - `Sc`: score type.
pub struct IncrementalUniConstraint<S, A, E, F, W, Sc: Score> {
    /// Identity of the constraint; its `name` backs `name()`.
    constraint_ref: ConstraintRef,
    /// Decides the sign applied to each weight (penalty negates, reward keeps).
    impact_type: ImpactType,
    /// Yields the entities to inspect for a given solution.
    extractor: E,
    /// Predicate selecting which entities count as matches.
    filter: F,
    /// Score magnitude of a matching entity, before the impact sign is applied.
    weight: W,
    /// Flag reported verbatim by `is_hard()`.
    is_hard: bool,
    /// When `Some`, insert/retract notifications carrying any other descriptor
    /// index produce a zero score.
    expected_descriptor: Option<usize>,
    /// Marks `S`, `A` and `Sc` as used without storing values of those types.
    _phantom: PhantomData<(fn() -> S, fn() -> A, fn() -> Sc)>,
}
/// Construction and per-entity scoring helpers.
impl<S, A, E, F, W, Sc> IncrementalUniConstraint<S, A, E, F, W, Sc>
where
    S: Send + Sync + 'static,
    A: Clone + Send + Sync + 'static,
    E: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
    F: Fn(&S, &A) -> bool + Send + Sync,
    W: Fn(&A) -> Sc + Send + Sync,
    Sc: Score,
{
    /// Creates a constraint from its parts.
    ///
    /// The new constraint has no descriptor restriction; add one with
    /// [`with_descriptor`](Self::with_descriptor).
    pub fn new(
        constraint_ref: ConstraintRef,
        impact_type: ImpactType,
        extractor: E,
        filter: F,
        weight: W,
        is_hard: bool,
    ) -> Self {
        let expected_descriptor = None;
        Self {
            constraint_ref,
            impact_type,
            extractor,
            filter,
            weight,
            is_hard,
            expected_descriptor,
            _phantom: PhantomData,
        }
    }

    /// Consumes the constraint and returns it restricted to `descriptor_index`.
    ///
    /// Afterwards, insert/retract notifications for any other descriptor index
    /// yield a zero score.
    pub fn with_descriptor(self, descriptor_index: usize) -> Self {
        let mut restricted = self;
        restricted.expected_descriptor = Some(descriptor_index);
        restricted
    }

    /// Returns whether the filter accepts `entity` in the context of `solution`.
    #[inline]
    fn matches(&self, solution: &S, entity: &A) -> bool {
        let accepts = &self.filter;
        accepts(solution, entity)
    }

    /// Score contribution of `entity`: its weight, negated for penalties.
    #[inline]
    fn compute_delta(&self, entity: &A) -> Sc {
        let Self {
            weight,
            impact_type,
            ..
        } = self;
        let magnitude = weight(entity);
        match impact_type {
            ImpactType::Reward => magnitude,
            ImpactType::Penalty => -magnitude,
        }
    }

    /// Score change that undoes `entity`'s contribution: the weight with the
    /// opposite sign to `compute_delta`.
    #[inline]
    fn reverse_delta(&self, entity: &A) -> Sc {
        let Self {
            weight,
            impact_type,
            ..
        } = self;
        let magnitude = weight(entity);
        match impact_type {
            ImpactType::Reward => -magnitude,
            ImpactType::Penalty => magnitude,
        }
    }
}
impl<S, A, E, F, W, Sc> IncrementalConstraint<S, Sc> for IncrementalUniConstraint<S, A, E, F, W, Sc>
where
S: Send + Sync + 'static,
A: Clone + Debug + Send + Sync + 'static,
E: crate::stream::collection_extract::CollectionExtract<S, Item = A>,
F: Fn(&S, &A) -> bool + Send + Sync,
W: Fn(&A) -> Sc + Send + Sync,
Sc: Score,
{
fn evaluate(&self, solution: &S) -> Sc {
let entities = self.extractor.extract(solution);
let mut total = Sc::zero();
for entity in entities {
if self.matches(solution, entity) {
total = total + self.compute_delta(entity);
}
}
total
}
fn match_count(&self, solution: &S) -> usize {
let entities = self.extractor.extract(solution);
entities
.iter()
.filter(|e| self.matches(solution, e))
.count()
}
fn initialize(&mut self, solution: &S) -> Sc {
self.evaluate(solution)
}
fn on_insert(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
if let Some(expected) = self.expected_descriptor {
if descriptor_index != expected {
return Sc::zero();
}
}
let entities = self.extractor.extract(solution);
if entity_index >= entities.len() {
return Sc::zero();
}
let entity = &entities[entity_index];
if self.matches(solution, entity) {
self.compute_delta(entity)
} else {
Sc::zero()
}
}
fn on_retract(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
if let Some(expected) = self.expected_descriptor {
if descriptor_index != expected {
return Sc::zero();
}
}
let entities = self.extractor.extract(solution);
if entity_index >= entities.len() {
return Sc::zero();
}
let entity = &entities[entity_index];
if self.matches(solution, entity) {
self.reverse_delta(entity)
} else {
Sc::zero()
}
}
fn reset(&mut self) {
}
fn name(&self) -> &str {
&self.constraint_ref.name
}
fn is_hard(&self) -> bool {
self.is_hard
}
fn constraint_ref(&self) -> ConstraintRef {
self.constraint_ref.clone()
}
fn get_matches(&self, solution: &S) -> Vec<DetailedConstraintMatch<Sc>> {
let entities = self.extractor.extract(solution);
let cref = self.constraint_ref.clone();
entities
.iter()
.filter(|e| self.matches(solution, e))
.map(|entity| {
let entity_ref = EntityRef::new(entity);
let justification = ConstraintJustification::new(vec![entity_ref]);
DetailedConstraintMatch::new(
cref.clone(),
self.compute_delta(entity),
justification,
)
})
.collect()
}
fn weight(&self) -> Sc {
Sc::zero()
}
}
/// Debug output limited to the constraint name and impact type; the extractor,
/// filter and weight are left out, so they do not need to implement `Debug`.
impl<S, A, E, F, W, Sc: Score> std::fmt::Debug for IncrementalUniConstraint<S, A, E, F, W, Sc> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            constraint_ref,
            impact_type,
            ..
        } = self;
        let mut out = f.debug_struct("IncrementalUniConstraint");
        out.field("name", &constraint_ref.name);
        out.field("impact_type", impact_type);
        out.finish()
    }
}