// solverforge-scoring 0.10.0
//
// Incremental constraint scoring for SolverForge.
// Documentation
use std::collections::HashMap;
use std::marker::PhantomData;

use solverforge_core::score::Score;
use solverforge_core::{ConstraintRef, ImpactType};

use crate::api::constraint_set::IncrementalConstraint;
use crate::stream::filter::UniFilter;
use crate::stream::ProjectedSource;

/// Incremental constraint that scores the outputs of a projected source.
///
/// Each output that passes `filter` is turned into a score by `weight` and then
/// applied as a penalty (negated) or reward (as-is) according to `impact_type`.
/// Signed contributions are cached per entity so an entity's effect can later be
/// retracted from the cache alone, without re-evaluating the solution.
pub struct ProjectedUniConstraint<S, Out, Src, F, W, Sc: Score> {
    /// Identity of this constraint; its `name` is what `name()` reports.
    constraint_ref: ConstraintRef,
    /// `Penalty` negates each weighted match, `Reward` keeps it unchanged.
    impact_type: ImpactType,
    /// Source of the projected `Out` values this constraint inspects.
    source: Src,
    /// Predicate selecting which projected outputs count as matches.
    filter: F,
    /// Score of one matching output before the impact sign is applied.
    weight: W,
    /// Value reported verbatim by `is_hard()`.
    is_hard: bool,
    /// Signed contributions cached per `(source slot, entity index)` key.
    entity_contributions: HashMap<(usize, usize), Vec<Sc>>,
    /// Mentions `S` and `Out` without storing them. The `fn() -> T` form keeps
    /// this marker covariant and `Send`/`Sync` regardless of `S` and `Out`.
    _phantom: PhantomData<(fn() -> S, fn() -> Out)>,
}

impl<S, Out, Src, F, W, Sc> ProjectedUniConstraint<S, Out, Src, F, W, Sc>
where
    S: Send + Sync + 'static,
    Out: Clone + Send + Sync + 'static,
    Src: ProjectedSource<S, Out>,
    F: UniFilter<S, Out>,
    W: Fn(&Out) -> Sc + Send + Sync,
    Sc: Score + 'static,
{
    /// Creates the constraint with an empty contribution cache.
    ///
    /// The cache is populated by `initialize` and by the incremental insert
    /// path (`insert_entity_outputs`).
    pub fn new(
        constraint_ref: ConstraintRef,
        impact_type: ImpactType,
        source: Src,
        filter: F,
        weight: W,
        is_hard: bool,
    ) -> Self {
        Self {
            constraint_ref,
            impact_type,
            source,
            filter,
            weight,
            is_hard,
            entity_contributions: HashMap::new(),
            _phantom: PhantomData,
        }
    }

    /// Returns the signed score of one matching output.
    ///
    /// The result is the weight, negated for `ImpactType::Penalty` and
    /// unchanged for `ImpactType::Reward`.
    fn compute_score(&self, output: &Out) -> Sc {
        let base = (self.weight)(output);
        match self.impact_type {
            ImpactType::Penalty => -base,
            ImpactType::Reward => base,
        }
    }

    /// Scores every filtered output of `entity_index` in source `slot` and caches
    /// the individual contributions.
    ///
    /// Returns their sum, which is the score delta of this insertion. Any
    /// contributions previously cached for the same key are replaced.
    fn insert_entity_outputs(&mut self, solution: &S, slot: usize, entity_index: usize) -> Sc {
        let mut total = Sc::zero();
        let mut contributions = Vec::new();
        self.source
            .collect_entity(solution, slot, entity_index, |_, output| {
                if self.filter.test(solution, &output) {
                    // Reuse the shared helper so the impact sign lives in one place.
                    let contribution = self.compute_score(&output);
                    total = total + contribution;
                    contributions.push(contribution);
                }
            });

        let key = (slot, entity_index);
        if contributions.is_empty() {
            // Don't keep empty lists around: a missing key already retracts to
            // zero, and removing here also drops any stale entry for the key.
            self.entity_contributions.remove(&key);
        } else {
            self.entity_contributions.insert(key, contributions);
        }
        total
    }

    /// Removes the cached contributions of `entity_index` in source `slot`.
    ///
    /// Returns the negation of their sum, which is the score delta of the
    /// retraction. Returns zero when nothing is cached for that key.
    fn retract_entity_outputs(&mut self, slot: usize, entity_index: usize) -> Sc {
        self.entity_contributions
            .remove(&(slot, entity_index))
            .unwrap_or_default()
            .into_iter()
            .fold(Sc::zero(), |total, contribution| total - contribution)
    }

    /// Returns the affected source slots, in ascending order.
    ///
    /// A slot is included when
    /// `change_source(slot).assert_localizes(descriptor_index, name)` returns
    /// true.
    ///
    /// NOTE(review): `assert_localizes` is defined elsewhere. Judging by its
    /// name, it may panic for slots that cannot localize a change — confirm.
    fn localized_slots(&self, descriptor_index: usize) -> Vec<usize> {
        (0..self.source.source_count())
            .filter(|&slot| {
                self.source
                    .change_source(slot)
                    .assert_localizes(descriptor_index, &self.constraint_ref.name)
            })
            .collect()
    }
}

impl<S, Out, Src, F, W, Sc> IncrementalConstraint<S, Sc>
    for ProjectedUniConstraint<S, Out, Src, F, W, Sc>
where
    S: Send + Sync + 'static,
    Out: Clone + Send + Sync + 'static,
    Src: ProjectedSource<S, Out>,
    F: UniFilter<S, Out>,
    W: Fn(&Out) -> Sc + Send + Sync,
    Sc: Score + 'static,
{
    /// Scores the whole solution from scratch without touching the cache.
    ///
    /// The result is the sum of the signed scores of every output that passes
    /// the filter.
    fn evaluate(&self, solution: &S) -> Sc {
        let mut total = Sc::zero();
        self.source.collect_all(solution, |_, output| {
            if self.filter.test(solution, &output) {
                total = total + self.compute_score(&output);
            }
        });
        total
    }

    /// Counts the projected outputs that pass the filter, without touching the
    /// cache.
    fn match_count(&self, solution: &S) -> usize {
        let mut count = 0;
        self.source.collect_all(solution, |_, output| {
            if self.filter.test(solution, &output) {
                count += 1;
            }
        });
        count
    }

    /// Clears the cache, re-scores the whole solution, and returns the total.
    ///
    /// Every matching contribution is cached under its
    /// `(source_slot, entity_index)` coordinate.
    fn initialize(&mut self, solution: &S) -> Sc {
        self.reset();
        let mut total = Sc::zero();
        // Borrow the fields separately. The closure mutates the cache while the
        // source is borrowed, so it cannot call `self.compute_score`.
        let source = &self.source;
        let filter = &self.filter;
        let weight = &self.weight;
        let impact = self.impact_type;
        let entity_contributions = &mut self.entity_contributions;
        source.collect_all(solution, |coordinate, output| {
            if !filter.test(solution, &output) {
                return;
            }
            let base = weight(&output);
            let contribution = match impact {
                ImpactType::Penalty => -base,
                ImpactType::Reward => base,
            };
            total = total + contribution;
            // One hash lookup per output instead of a remove-then-insert pair.
            entity_contributions
                .entry((coordinate.source_slot, coordinate.entity_index))
                .or_default()
                .push(contribution);
        });
        total
    }

    /// Adds the entity's outputs in every localized slot to the cache and
    /// returns the resulting score delta.
    fn on_insert(&mut self, solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
        let mut total = Sc::zero();
        for slot in self.localized_slots(descriptor_index) {
            total = total + self.insert_entity_outputs(solution, slot, entity_index);
        }
        total
    }

    /// Retracts the entity's cached outputs in every localized slot and returns
    /// the resulting score delta.
    ///
    /// The solution is not consulted; the delta comes entirely from the cache.
    fn on_retract(&mut self, _solution: &S, entity_index: usize, descriptor_index: usize) -> Sc {
        let mut total = Sc::zero();
        for slot in self.localized_slots(descriptor_index) {
            total = total + self.retract_entity_outputs(slot, entity_index);
        }
        total
    }

    /// Drops every cached contribution.
    fn reset(&mut self) {
        self.entity_contributions.clear();
    }

    /// Returns the constraint name from its `ConstraintRef`.
    fn name(&self) -> &str {
        &self.constraint_ref.name
    }

    /// Returns the `is_hard` flag supplied at construction.
    fn is_hard(&self) -> bool {
        self.is_hard
    }

    /// Returns a clone of this constraint's `ConstraintRef`.
    fn constraint_ref(&self) -> ConstraintRef {
        self.constraint_ref.clone()
    }
}